refactor: use async-aware RwLocks and Mutexes where possible
squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595 Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
46b543eebe
commit
4ec2d3ecb5
20 changed files with 174 additions and 194 deletions
|
@ -4,7 +4,6 @@ type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
|
|||
use std::{
|
||||
collections::{hash_map, HashSet},
|
||||
pin::Pin,
|
||||
sync::RwLockWriteGuard,
|
||||
time::{Duration, Instant, SystemTime},
|
||||
};
|
||||
|
||||
|
@ -33,12 +32,12 @@ use ruma::{
|
|||
OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
|
||||
};
|
||||
use serde_json::value::RawValue as RawJsonValue;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
|
||||
use tracing::{debug, error, info, trace, warn};
|
||||
|
||||
use super::state_compressor::CompressedStateEvent;
|
||||
use crate::{
|
||||
service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock},
|
||||
service::{pdu, Arc, BTreeMap, HashMap, Result},
|
||||
services, Error, PduEvent,
|
||||
};
|
||||
|
||||
|
@ -168,7 +167,7 @@ impl Service {
|
|||
));
|
||||
}
|
||||
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) {
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -183,7 +182,7 @@ impl Service {
|
|||
|
||||
if errors >= 5 {
|
||||
// Timeout other events
|
||||
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
|
||||
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
|
@ -205,7 +204,7 @@ impl Service {
|
|||
.globals
|
||||
.roomid_federationhandletime
|
||||
.write()
|
||||
.unwrap()
|
||||
.await
|
||||
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
|
||||
|
||||
if let Err(e) =
|
||||
|
@ -213,7 +212,7 @@ impl Service {
|
|||
{
|
||||
errors += 1;
|
||||
warn!("Prev event {} failed: {}", prev_id, e);
|
||||
match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
|
||||
match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
|
@ -223,7 +222,7 @@ impl Service {
|
|||
}
|
||||
}
|
||||
let elapsed = start_time.elapsed();
|
||||
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
|
||||
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
|
||||
debug!(
|
||||
"Handling prev event {} took {}m{}s",
|
||||
prev_id,
|
||||
|
@ -240,14 +239,14 @@ impl Service {
|
|||
.globals
|
||||
.roomid_federationhandletime
|
||||
.write()
|
||||
.unwrap()
|
||||
.await
|
||||
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
|
||||
let r = services()
|
||||
.rooms
|
||||
.event_handler
|
||||
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
|
||||
.await;
|
||||
services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
|
||||
services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
|
||||
|
||||
r
|
||||
}
|
||||
|
@ -275,11 +274,8 @@ impl Service {
|
|||
let room_version_id = &create_event_content.room_version;
|
||||
let room_version = RoomVersion::new(room_version_id).expect("room version is supported");
|
||||
|
||||
let mut val = match ruma::signatures::verify_event(
|
||||
&pub_key_map.read().expect("RwLock is poisoned."),
|
||||
&value,
|
||||
room_version_id,
|
||||
) {
|
||||
let guard = pub_key_map.read().await;
|
||||
let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
|
||||
Err(e) => {
|
||||
// Drop
|
||||
warn!("Dropping bad event {}: {}", event_id, e,);
|
||||
|
@ -306,6 +302,8 @@ impl Service {
|
|||
Ok(ruma::signatures::Verified::All) => value,
|
||||
};
|
||||
|
||||
drop(guard);
|
||||
|
||||
// Now that we have checked the signature and hashes we can add the eventID and
|
||||
// convert to our PduEvent type
|
||||
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
|
||||
|
@ -580,10 +578,13 @@ impl Service {
|
|||
{
|
||||
Ok(res) => {
|
||||
debug!("Fetching state events at event.");
|
||||
|
||||
let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>();
|
||||
|
||||
let state_vec = self
|
||||
.fetch_and_handle_outliers(
|
||||
origin,
|
||||
&res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(),
|
||||
&collect,
|
||||
create_event,
|
||||
room_id,
|
||||
room_version_id,
|
||||
|
@ -680,7 +681,7 @@ impl Service {
|
|||
|
||||
// We start looking at current room state now, so lets lock the room
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
// Now we calculate the set of extremities this room has after the incoming
|
||||
|
@ -876,11 +877,13 @@ impl Service {
|
|||
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
|
||||
) -> AsyncRecursiveCanonicalJsonVec<'a> {
|
||||
Box::pin(async move {
|
||||
let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
let back_off = |id| async {
|
||||
match services().globals.bad_event_ratelimiter.write().await.entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
}
|
||||
};
|
||||
|
||||
let mut events_with_auth_events = vec![];
|
||||
|
@ -902,8 +905,7 @@ impl Service {
|
|||
let mut events_all = HashSet::new();
|
||||
let mut i = 0;
|
||||
while let Some(next_id) = todo_auth_events.pop() {
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id)
|
||||
{
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -1009,9 +1011,7 @@ impl Service {
|
|||
pdus.push((local_pdu, None));
|
||||
}
|
||||
for (next_id, value) in events_in_reverse_order.iter().rev() {
|
||||
if let Some((time, tries)) =
|
||||
services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id)
|
||||
{
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -1205,10 +1205,7 @@ impl Service {
|
|||
while let Some(fetch_res) = server_keys.next().await {
|
||||
match fetch_res {
|
||||
Ok((signature_server, keys)) => {
|
||||
pub_key_map
|
||||
.write()
|
||||
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
|
||||
.insert(signature_server.clone(), keys);
|
||||
pub_key_map.write().await.insert(signature_server.clone(), keys);
|
||||
},
|
||||
Err((signature_server, e)) => {
|
||||
warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
|
||||
|
@ -1222,7 +1219,7 @@ impl Service {
|
|||
// Gets a list of servers for which we don't have the signing key yet. We go
|
||||
// over the PDUs and either cache the key or add it to the list that needs to be
|
||||
// retrieved.
|
||||
fn get_server_keys_from_cache(
|
||||
async fn get_server_keys_from_cache(
|
||||
&self, pdu: &RawJsonValue,
|
||||
servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
|
||||
room_version: &RoomVersionId,
|
||||
|
@ -1239,7 +1236,7 @@ impl Service {
|
|||
);
|
||||
let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");
|
||||
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) {
|
||||
if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -1310,7 +1307,7 @@ impl Service {
|
|||
.await
|
||||
{
|
||||
debug!("Got signing keys: {:?}", keys);
|
||||
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
|
||||
let mut pkm = pub_key_map.write().await;
|
||||
for k in keys.server_keys {
|
||||
let k = match k.deserialize() {
|
||||
Ok(key) => key,
|
||||
|
@ -1365,10 +1362,7 @@ impl Service {
|
|||
.into_iter()
|
||||
.map(|(k, v)| (k.to_string(), v.key))
|
||||
.collect();
|
||||
pub_key_map
|
||||
.write()
|
||||
.map_err(|_| Error::bad_database("RwLock is poisoned."))?
|
||||
.insert(origin.to_string(), result);
|
||||
pub_key_map.write().await.insert(origin.to_string(), result);
|
||||
}
|
||||
}
|
||||
debug!("Done handling Future result");
|
||||
|
@ -1384,15 +1378,15 @@ impl Service {
|
|||
let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new();
|
||||
|
||||
{
|
||||
let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
|
||||
let mut pkm = pub_key_map.write().await;
|
||||
|
||||
// Try to fetch keys, failure is okay
|
||||
// Servers we couldn't find in the cache will be added to `servers`
|
||||
for pdu in &event.room_state.state {
|
||||
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
|
||||
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
|
||||
}
|
||||
for pdu in &event.room_state.auth_chain {
|
||||
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
|
||||
let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
|
||||
}
|
||||
|
||||
drop(pkm);
|
||||
|
@ -1491,18 +1485,13 @@ impl Service {
|
|||
) -> Result<BTreeMap<String, Base64>> {
|
||||
let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));
|
||||
|
||||
let permit = services()
|
||||
.globals
|
||||
.servername_ratelimiter
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(origin)
|
||||
.map(|s| Arc::clone(s).acquire_owned());
|
||||
let permit =
|
||||
services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned());
|
||||
|
||||
let permit = match permit {
|
||||
Some(p) => p,
|
||||
None => {
|
||||
let mut write = services().globals.servername_ratelimiter.write().unwrap();
|
||||
let mut write = services().globals.servername_ratelimiter.write().await;
|
||||
let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1))));
|
||||
|
||||
s.acquire_owned()
|
||||
|
@ -1510,14 +1499,16 @@ impl Service {
|
|||
}
|
||||
.await;
|
||||
|
||||
let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
let back_off = |id| async {
|
||||
match services().globals.bad_signature_ratelimiter.write().await.entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
}
|
||||
};
|
||||
|
||||
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) {
|
||||
if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -1591,7 +1582,7 @@ impl Service {
|
|||
|
||||
drop(permit);
|
||||
|
||||
back_off(signature_ids);
|
||||
back_off(signature_ids).await;
|
||||
|
||||
warn!("Failed to find public key for server: {}", origin);
|
||||
Err(Error::BadServerResponse("Failed to find public key for server"))
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue