refactor: use async-aware RwLocks and Mutexes where possible

squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Matthias Ahouansou 2024-03-05 20:52:16 -05:00 committed by June
parent 46b543eebe
commit 4ec2d3ecb5
20 changed files with 174 additions and 194 deletions

View file

@ -1,6 +1,6 @@
use std::{
collections::{hash_map::Entry, BTreeMap, HashMap, HashSet},
sync::{Arc, RwLock},
sync::Arc,
time::{Duration, Instant},
};
@ -29,6 +29,7 @@ use ruma::{
OwnedUserId, RoomId, RoomVersionId, UserId,
};
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use super::get_alias_helper;
@ -242,7 +243,7 @@ pub async fn kick_user_route(body: Ruma<kick_user::v3::Request>) -> Result<kick_
event.reason = body.reason.clone();
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -303,7 +304,7 @@ pub async fn ban_user_route(body: Ruma<ban_user::v3::Request>) -> Result<ban_use
)?;
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -349,7 +350,7 @@ pub async fn unban_user_route(body: Ruma<unban_user::v3::Request>) -> Result<unb
event.reason = body.reason.clone();
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(body.room_id.clone()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -480,7 +481,7 @@ async fn join_room_by_id_helper(
let sender_user = sender_user.expect("user is authenticated");
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
// Ask a remote server if we are not participating in this room
@ -680,7 +681,7 @@ async fn join_room_by_id_helper(
.iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map))
{
let (event_id, value) = match result {
let (event_id, value) = match result.await {
Ok(t) => t,
Err(_) => continue,
};
@ -705,7 +706,7 @@ async fn join_room_by_id_helper(
.iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map))
{
let (event_id, value) = match result {
let (event_id, value) = match result.await {
Ok(t) => t,
Err(_) => continue,
};
@ -1048,7 +1049,7 @@ async fn make_join_request(
make_join_response_and_server
}
fn validate_and_add_event_id(
async fn validate_and_add_event_id(
pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<(OwnedEventId, CanonicalJsonObject)> {
let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
@ -1061,14 +1062,16 @@ fn validate_and_add_event_id(
))
.expect("ruma's reference hashes are valid event ids");
let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) {
Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&event_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&event_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1081,13 +1084,9 @@ fn validate_and_add_event_id(
}
}
if let Err(e) = ruma::signatures::verify_event(
&*pub_key_map.read().map_err(|_| Error::bad_database("RwLock is poisoned."))?,
&value,
room_version,
) {
if let Err(e) = ruma::signatures::verify_event(&*pub_key_map.read().await, &value, room_version) {
warn!("Event {} failed verification {:?} {}", event_id, pdu, e);
back_off(event_id);
back_off(event_id).await;
return Err(Error::BadServerResponse("Event failed verification."));
}
@ -1109,9 +1108,8 @@ pub(crate) async fn invite_helper(
if user_id.server_name() != services().globals.server_name() {
let (pdu, pdu_json, invite_room_state) = {
let mutex_state = Arc::clone(
services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default(),
);
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
let content = to_raw_value(&RoomMemberEventContent {
@ -1229,7 +1227,7 @@ pub(crate) async fn invite_helper(
}
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
services()
@ -1314,7 +1312,7 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option<Strin
.await?;
} else {
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
let member_event =