refactor: use async-aware RwLocks and Mutexes where possible

squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Matthias Ahouansou 2024-03-05 20:52:16 -05:00 committed by June
parent 46b543eebe
commit 4ec2d3ecb5
20 changed files with 174 additions and 194 deletions

View file

@ -4,7 +4,6 @@ type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use std::{
collections::{hash_map, HashSet},
pin::Pin,
sync::RwLockWriteGuard,
time::{Duration, Instant, SystemTime},
};
@ -33,12 +32,12 @@ use ruma::{
OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
};
use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::Semaphore;
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
use tracing::{debug, error, info, trace, warn};
use super::state_compressor::CompressedStateEvent;
use crate::{
service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock},
service::{pdu, Arc, BTreeMap, HashMap, Result},
services, Error, PduEvent,
};
@ -168,7 +167,7 @@ impl Service {
));
}
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -183,7 +182,7 @@ impl Service {
if errors >= 5 {
// Timeout other events
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
@ -205,7 +204,7 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) =
@ -213,7 +212,7 @@ impl Service {
{
errors += 1;
warn!("Prev event {} failed: {}", prev_id, e);
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
@ -223,7 +222,7 @@ impl Service {
}
}
let elapsed = start_time.elapsed();
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
debug!(
"Handling prev event {} took {}m{}s",
prev_id,
@ -240,14 +239,14 @@ impl Service {
.globals
.roomid_federationhandletime
.write()
.unwrap()
.await
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = services()
.rooms
.event_handler
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
.await;
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
r
}
@ -275,11 +274,8 @@ impl Service {
let room_version_id = &create_event_content.room_version;
let room_version = RoomVersion::new(room_version_id).expect("room version is supported");
let mut val = match ruma::signatures::verify_event(
&pub_key_map.read().expect("RwLock is poisoned."),
&value,
room_version_id,
) {
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {}: {}", event_id, e,);
@ -306,6 +302,8 @@ impl Service {
Ok(ruma::signatures::Verified::All) => value,
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
@ -580,10 +578,13 @@ impl Service {
{
Ok(res) => {
debug!("Fetching state events at event.");
let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(
origin,
&res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(),
&collect,
create_event,
room_id,
room_version_id,
@ -680,7 +681,7 @@ impl Service {
// We start looking at current room state now, so lets lock the room
let mutex_state =
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
let state_lock = mutex_state.lock().await;
// Now we calculate the set of extremities this room has after the incoming
@ -876,11 +877,13 @@ impl Service {
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveCanonicalJsonVec<'a> {
Box::pin(async move {
let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
let back_off = |id| async {
match services().globals.bad_event_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
let mut events_with_auth_events = vec![];
@ -902,8 +905,7 @@ impl Service {
let mut events_all = HashSet::new();
let mut i = 0;
while let Some(next_id) = todo_auth_events.pop() {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id)
{
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1009,9 +1011,7 @@ impl Service {
pdus.push((local_pdu, None));
}
for (next_id, value) in events_in_reverse_order.iter().rev() {
if let Some((time, tries)) =
services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id)
{
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1205,10 +1205,7 @@ impl Service {
while let Some(fetch_res) = server_keys.next().await {
match fetch_res {
Ok((signature_server, keys)) => {
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(signature_server.clone(), keys);
pub_key_map.write().await.insert(signature_server.clone(), keys);
},
Err((signature_server, e)) => {
warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
@ -1222,7 +1219,7 @@ impl Service {
// Gets a list of servers for which we don't have the signing key yet. We go
// over the PDUs and either cache the key or add it to the list that needs to be
// retrieved.
fn get_server_keys_from_cache(
async fn get_server_keys_from_cache(
&self, pdu: &RawJsonValue,
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
room_version: &RoomVersionId,
@ -1239,7 +1236,7 @@ impl Service {
);
let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) {
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1310,7 +1307,7 @@ impl Service {
.await
{
debug!("Got signing keys: {:?}", keys);
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
for k in keys.server_keys {
let k = match k.deserialize() {
Ok(key) => key,
@ -1365,10 +1362,7 @@ impl Service {
.into_iter()
.map(|(k, v)| (k.to_string(), v.key))
.collect();
pub_key_map
.write()
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
.insert(origin.to_string(), result);
pub_key_map.write().await.insert(origin.to_string(), result);
}
}
debug!("Done handling Future result");
@ -1384,15 +1378,15 @@ impl Service {
let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new();
{
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
let mut pkm = pub_key_map.write().await;
// Try to fetch keys, failure is okay
// Servers we couldn't find in the cache will be added to `servers`
for pdu in &event.room_state.state {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
}
for pdu in &event.room_state.auth_chain {
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
}
drop(pkm);
@ -1491,18 +1485,13 @@ impl Service {
) -> Result<BTreeMap<String, Base64>> {
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
let permit = services()
.globals
.servername_ratelimiter
.read()
.unwrap()
.get(origin)
.map(|s| Arc::clone(s).acquire_owned());
let permit =
services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned());
let permit = match permit {
Some(p) => p,
None => {
let mut write = services().globals.servername_ratelimiter.write().unwrap();
let mut write = services().globals.servername_ratelimiter.write().await;
let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1))));
s.acquire_owned()
@ -1510,14 +1499,16 @@ impl Service {
}
.await;
let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
let back_off = |id| async {
match services().globals.bad_signature_ratelimiter.write().await.entry(id) {
hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1));
},
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
}
};
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) {
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) {
// Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@ -1591,7 +1582,7 @@ impl Service {
drop(permit);
back_off(signature_ids);
back_off(signature_ids).await;
warn!("Failed to find public key for server: {}", origin);
Err(Error::BadServerResponse("Failed to find public key for server"))