refactor: use async-aware RwLocks and Mutexes where possible
squashed from https://gitlab.com/famedly/conduit/-/merge_requests/595

Signed-off-by: strawberry <strawberry@puppygock.gay>
parent 46b543eebe
commit 4ec2d3ecb5

20 changed files with 174 additions and 194 deletions

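The pattern repeated throughout the diff below: guards from std::sync locks, which are taken with .unwrap()/.expect() and block the whole thread, are swapped for tokio::sync guards taken with .await, which let the task yield while waiting and may be held across await points. A minimal sketch of the before/after shape, loosely modelled on the per-room mutex maps in the globals service (the Globals struct, field, and method here are illustrative stand-ins, not the real code):

    use std::{collections::HashMap, sync::Arc};

    use tokio::sync::{Mutex, RwLock};

    // Illustrative stand-in for the per-room mutex maps kept in the globals service.
    struct Globals {
        roomid_mutex_state: RwLock<HashMap<String, Arc<Mutex<()>>>>,
    }

    impl Globals {
        // Before: roomid_mutex_state.write().unwrap().entry(room_id).or_default()
        // After:  the write guard is awaited instead of unwrapped, so the task yields
        //         while it waits and there is no lock poisoning to handle.
        async fn room_mutex(&self, room_id: &str) -> Arc<Mutex<()>> {
            Arc::clone(
                self.roomid_mutex_state
                    .write()
                    .await // was: .unwrap()
                    .entry(room_id.to_owned())
                    .or_default(),
            )
        }
    }

    #[tokio::main]
    async fn main() {
        let globals = Globals { roomid_mutex_state: RwLock::new(HashMap::new()) };
        let mutex = globals.room_mutex("!room:example.org").await;
        let _state_lock = mutex.lock().await; // may be held across later .await points
        println!("state lock acquired");
    }

Locks that are only ever used from synchronous code, such as the LruCache maps and the TLS override map, stay on std::sync types in the hunks below (renamed to StdMutex/StdRwLock where both kinds are imported).
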
@@ -1,9 +1,4 @@
-use std::{
-collections::BTreeMap,
-fmt::Write as _,
-sync::{Arc, RwLock},
-time::Instant,
-};
+use std::{collections::BTreeMap, fmt::Write as _, sync::Arc, time::Instant};

 use clap::{Parser, Subcommand};
 use regex::Regex;
@@ -29,7 +24,7 @@ use ruma::{
 ServerName, UserId,
 };
 use serde_json::value::to_raw_value;
-use tokio::sync::{mpsc, Mutex};
+use tokio::sync::{mpsc, Mutex, RwLock};
 use tracing::{debug, error, info, warn};

 use super::pdu::PduBuilder;
@@ -473,7 +468,7 @@ impl Service {
 services().globals
 .roomid_mutex_state
 .write()
-.unwrap()
+.await
 .entry(conduit_room.clone())
 .or_default(),
 );
@@ -1773,7 +1768,7 @@ impl Service {
 RoomMessageEventContent::text_plain("Room enabled.")
 },
 FederationCommand::IncomingFederation => {
-let map = services().globals.roomid_federationhandletime.read().unwrap();
+let map = services().globals.roomid_federationhandletime.read().await;
 let mut msg: String = format!("Handling {} incoming pdus:\n", map.len());

 for (r, (e, i)) in map.iter() {
@@ -1818,7 +1813,7 @@ impl Service {
 .fetch_required_signing_keys([&value], &pub_key_map)
 .await?;

-let pub_key_map = pub_key_map.read().unwrap();
+let pub_key_map = pub_key_map.read().await;
 match ruma::signatures::verify_json(&pub_key_map, &value) {
 Ok(_) => RoomMessageEventContent::text_plain("Signature correct"),
 Err(e) => RoomMessageEventContent::text_plain(format!(
@@ -1841,7 +1836,7 @@ impl Service {
 RoomMessageEventContent::text_plain(format!("{}", services().globals.config))
 },
 ServerCommand::MemoryUsage => {
-let response1 = services().memory_usage();
+let response1 = services().memory_usage().await;
 let response2 = services().globals.db.memory_usage();

 RoomMessageEventContent::text_plain(format!("Services:\n{response1}\n\nDatabase:\n{response2}"))
@@ -1856,7 +1851,7 @@ impl Service {
 ServerCommand::ClearServiceCaches {
 amount,
 } => {
-services().clear_caches(amount);
+services().clear_caches(amount).await;

 RoomMessageEventContent::text_plain("Done.")
 },
@@ -2047,7 +2042,7 @@ impl Service {
 services().rooms.short.get_or_create_shortroomid(&room_id)?;

 let mutex_state =
-Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
+Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
 let state_lock = mutex_state.lock().await;

 // Create a user for the server
@@ -2291,7 +2286,7 @@ impl Service {
 let room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias)?.expect("Admin room must exist");

 let mutex_state =
-Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
+Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.clone()).or_default());
 let state_lock = mutex_state.lock().await;

 // Use the server user to grant the new admin's power level

@@ -8,7 +8,7 @@ use std::{
 path::PathBuf,
 sync::{
 atomic::{self, AtomicBool},
-Arc, Mutex, RwLock,
+Arc, RwLock as StdRwLock,
 },
 time::{Duration, Instant},
 };
@@ -33,7 +33,7 @@ use ruma::{
 RoomVersionId, ServerName, UserId,
 };
 use sha2::Digest;
-use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
+use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore};
 use tracing::{error, info};
 use trust_dns_resolver::TokioAsyncResolver;

@@ -53,7 +53,7 @@ pub struct Service<'a> {
 pub db: &'static dyn Data,

 pub actual_destination_cache: Arc<RwLock<WellKnownMap>>, // actual_destination, host
-pub tls_name_override: Arc<RwLock<TlsNameMap>>,
+pub tls_name_override: Arc<StdRwLock<TlsNameMap>>,
 pub config: Config,
 keypair: Arc<ruma::signatures::Ed25519KeyPair>,
 dns_resolver: TokioAsyncResolver,
@@ -68,9 +68,9 @@ pub struct Service<'a> {
 pub bad_query_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, RateLimitState>>>,
 pub servername_ratelimiter: Arc<RwLock<HashMap<OwnedServerName, Arc<Semaphore>>>>,
 pub sync_receivers: RwLock<HashMap<(OwnedUserId, OwnedDeviceId), SyncHandle>>,
-pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
-pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>,
-pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<TokioMutex<()>>>>, // this lock will be held longer
+pub roomid_mutex_insert: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
+pub roomid_mutex_state: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>,
+pub roomid_mutex_federation: RwLock<HashMap<OwnedRoomId, Arc<Mutex<()>>>>, // this lock will be held longer
 pub roomid_federationhandletime: RwLock<HashMap<OwnedRoomId, (OwnedEventId, Instant)>>,
 pub stateres_mutex: Arc<Mutex<()>>,
 pub(crate) rotate: RotationHandler,
@@ -109,11 +109,11 @@ impl Default for RotationHandler {

 struct Resolver {
 inner: GaiResolver,
-overrides: Arc<RwLock<TlsNameMap>>,
+overrides: Arc<StdRwLock<TlsNameMap>>,
 }

 impl Resolver {
-fn new(overrides: Arc<RwLock<TlsNameMap>>) -> Self {
+fn new(overrides: Arc<StdRwLock<TlsNameMap>>) -> Self {
 Resolver {
 inner: GaiResolver::new(),
 overrides,
@@ -125,7 +125,7 @@ impl Resolve for Resolver {
 fn resolve(&self, name: Name) -> Resolving {
 self.overrides
 .read()
-.expect("lock should not be poisoned")
+.unwrap()
 .get(name.as_str())
 .and_then(|(override_name, port)| {
 override_name.first().map(|first_name| {
@@ -159,7 +159,7 @@ impl Service<'_> {
 },
 };

-let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
+let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new()));

 let jwt_decoding_key =
 config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));

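One detail from the hunks above: tls_name_override keeps a synchronous lock, now imported as StdRwLock, because it is read inside reqwest's Resolve::resolve, a plain (non-async) method that cannot .await a tokio guard. A small sketch of that split, with illustrative types standing in for TlsNameMap and the destination cache:

    use std::{
        collections::HashMap,
        sync::{Arc, RwLock as StdRwLock},
    };

    use tokio::sync::RwLock;

    // Sync-side state: read from non-async code (standing in for reqwest's
    // `Resolve::resolve`, a plain method that cannot `.await`), so it keeps a
    // std lock and is unlocked with `.expect()`/`.unwrap()`.
    struct Overrides(Arc<StdRwLock<HashMap<String, u16>>>);

    impl Overrides {
        fn port_for(&self, host: &str) -> Option<u16> {
            self.0.read().expect("lock should not be poisoned").get(host).copied()
        }
    }

    // Async-side state: only touched from async tasks, so it can use tokio's
    // RwLock and yield to the executor instead of blocking while it waits.
    struct DestinationCache(RwLock<HashMap<String, String>>);

    impl DestinationCache {
        async fn get(&self, server: &str) -> Option<String> {
            self.0.read().await.get(server).cloned()
        }
    }

    #[tokio::main]
    async fn main() {
        let overrides = Overrides(Arc::new(StdRwLock::new(HashMap::new())));
        let cache = DestinationCache(RwLock::new(HashMap::new()));
        assert_eq!(overrides.port_for("matrix.example.org"), None);
        assert_eq!(cache.get("example.org").await, None);
    }
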
@@ -1,10 +1,5 @@
 mod data;
-use std::{
-collections::HashMap,
-io::Cursor,
-sync::{Arc, RwLock},
-time::SystemTime,
-};
+use std::{collections::HashMap, io::Cursor, sync::Arc, time::SystemTime};

 pub(crate) use data::Data;
 use image::imageops::FilterType;
@@ -13,7 +8,7 @@ use serde::Serialize;
 use tokio::{
 fs::{self, File},
 io::{AsyncReadExt, AsyncWriteExt, BufReader},
-sync::Mutex,
+sync::{Mutex, RwLock},
 };
 use tracing::{debug, error};

@@ -1,9 +1,10 @@
 use std::{
 collections::{BTreeMap, HashMap},
-sync::{Arc, Mutex, RwLock},
+sync::{Arc, Mutex as StdMutex},
 };

 use lru_cache::LruCache;
+use tokio::sync::{Mutex, RwLock};

 use crate::{Config, Result};
@@ -106,10 +107,10 @@ impl Services<'_> {
 },
 state_accessor: rooms::state_accessor::Service {
 db,
-server_visibility_cache: Mutex::new(LruCache::new(
+server_visibility_cache: StdMutex::new(LruCache::new(
 (100.0 * config.conduit_cache_capacity_modifier) as usize,
 )),
-user_visibility_cache: Mutex::new(LruCache::new(
+user_visibility_cache: StdMutex::new(LruCache::new(
 (100.0 * config.conduit_cache_capacity_modifier) as usize,
 )),
 },
@@ -118,7 +119,7 @@ impl Services<'_> {
 },
 state_compressor: rooms::state_compressor::Service {
 db,
-stateinfo_cache: Mutex::new(LruCache::new(
+stateinfo_cache: StdMutex::new(LruCache::new(
 (100.0 * config.conduit_cache_capacity_modifier) as usize,
 )),
 },
@@ -146,7 +147,7 @@ impl Services<'_> {
 },
 users: users::Service {
 db,
-connections: Mutex::new(BTreeMap::new()),
+connections: StdMutex::new(BTreeMap::new()),
 },
 account_data: account_data::Service {
 db,
@@ -165,13 +166,13 @@ impl Services<'_> {
 })
 }

-fn memory_usage(&self) -> String {
-let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
+async fn memory_usage(&self) -> String {
+let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().await.len();
 let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
 let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
 let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
-let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
-let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();
+let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().await.len();
+let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().await.len();

 format!(
 "\
@@ -184,9 +185,9 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
 )
 }

-fn clear_caches(&self, amount: u32) {
+async fn clear_caches(&self, amount: u32) {
 if amount > 0 {
-self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
+self.rooms.lazy_loading.lazy_load_waiting.lock().await.clear();
 }
 if amount > 1 {
 self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
@@ -198,10 +199,10 @@ roomid_spacechunk_cache: {roomid_spacechunk_cache}"
 self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
 }
 if amount > 4 {
-self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
+self.rooms.timeline.lasttimelinecount_cache.lock().await.clear();
 }
 if amount > 5 {
-self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
+self.rooms.spaces.roomid_spacechunk_cache.lock().await.clear();
 }
 }
 }

@@ -4,7 +4,6 @@ type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
 use std::{
 collections::{hash_map, HashSet},
 pin::Pin,
-sync::RwLockWriteGuard,
 time::{Duration, Instant, SystemTime},
 };

@@ -33,12 +32,12 @@ use ruma::{
 OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName,
 };
 use serde_json::value::RawValue as RawJsonValue;
-use tokio::sync::Semaphore;
+use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore};
 use tracing::{debug, error, info, trace, warn};

 use super::state_compressor::CompressedStateEvent;
 use crate::{
-service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock},
+service::{pdu, Arc, BTreeMap, HashMap, Result},
 services, Error, PduEvent,
 };

@@ -168,7 +167,7 @@ impl Service {
 ));
 }

-if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) {
+if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*prev_id) {
 // Exponential backoff
 let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -183,7 +182,7 @@ impl Service {

 if errors >= 5 {
 // Timeout other events
-match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
+match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
 hash_map::Entry::Vacant(e) => {
 e.insert((Instant::now(), 1));
 },
@@ -205,7 +204,7 @@ impl Service {
 .globals
 .roomid_federationhandletime
 .write()
-.unwrap()
+.await
 .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));

 if let Err(e) =
@@ -213,7 +212,7 @@ impl Service {
 {
 errors += 1;
 warn!("Prev event {} failed: {}", prev_id, e);
-match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) {
+match services().globals.bad_event_ratelimiter.write().await.entry((*prev_id).to_owned()) {
 hash_map::Entry::Vacant(e) => {
 e.insert((Instant::now(), 1));
 },
@@ -223,7 +222,7 @@ impl Service {
 }
 }
 let elapsed = start_time.elapsed();
-services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
+services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());
 debug!(
 "Handling prev event {} took {}m{}s",
 prev_id,
@@ -240,14 +239,14 @@ impl Service {
 .globals
 .roomid_federationhandletime
 .write()
-.unwrap()
+.await
 .insert(room_id.to_owned(), (event_id.to_owned(), start_time));
 let r = services()
 .rooms
 .event_handler
 .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
 .await;
-services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned());
+services().globals.roomid_federationhandletime.write().await.remove(&room_id.to_owned());

 r
 }
@@ -275,11 +274,8 @@ impl Service {
 let room_version_id = &create_event_content.room_version;
 let room_version = RoomVersion::new(room_version_id).expect("room version is supported");

-let mut val = match ruma::signatures::verify_event(
-&pub_key_map.read().expect("RwLock is poisoned."),
-&value,
-room_version_id,
-) {
+let guard = pub_key_map.read().await;
+let mut val = match ruma::signatures::verify_event(&guard, &value, room_version_id) {
 Err(e) => {
 // Drop
 warn!("Dropping bad event {}: {}", event_id, e,);
@@ -306,6 +302,8 @@ impl Service {
 Ok(ruma::signatures::Verified::All) => value,
 };

+drop(guard);
+
 // Now that we have checked the signature and hashes we can add the eventID and
 // convert to our PduEvent type
 val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
@@ -580,10 +578,13 @@ impl Service {
 {
 Ok(res) => {
 debug!("Fetching state events at event.");
+
+let collect = res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>();
+
 let state_vec = self
 .fetch_and_handle_outliers(
 origin,
-&res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::<Vec<_>>(),
+&collect,
 create_event,
 room_id,
 room_version_id,
@@ -680,7 +681,7 @@ impl Service {

 // We start looking at current room state now, so lets lock the room
 let mutex_state =
-Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
+Arc::clone(services().globals.roomid_mutex_state.write().await.entry(room_id.to_owned()).or_default());
 let state_lock = mutex_state.lock().await;

 // Now we calculate the set of extremities this room has after the incoming
@@ -876,11 +877,13 @@ impl Service {
 room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
 ) -> AsyncRecursiveCanonicalJsonVec<'a> {
 Box::pin(async move {
-let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) {
-hash_map::Entry::Vacant(e) => {
-e.insert((Instant::now(), 1));
-},
-hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
+let back_off = |id| async {
+match services().globals.bad_event_ratelimiter.write().await.entry(id) {
+hash_map::Entry::Vacant(e) => {
+e.insert((Instant::now(), 1));
+},
+hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
+}
 };

 let mut events_with_auth_events = vec![];
@@ -902,8 +905,7 @@ impl Service {
 let mut events_all = HashSet::new();
 let mut i = 0;
 while let Some(next_id) = todo_auth_events.pop() {
-if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id)
-{
+if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&*next_id) {
 // Exponential backoff
 let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -1009,9 +1011,7 @@ impl Service {
 pdus.push((local_pdu, None));
 }
 for (next_id, value) in events_in_reverse_order.iter().rev() {
-if let Some((time, tries)) =
-services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id)
-{
+if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(&**next_id) {
 // Exponential backoff
 let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -1205,10 +1205,7 @@ impl Service {
 while let Some(fetch_res) = server_keys.next().await {
 match fetch_res {
 Ok((signature_server, keys)) => {
-pub_key_map
-.write()
-.map_err(|_| Error::bad_database("RwLock is poisoned."))?
-.insert(signature_server.clone(), keys);
+pub_key_map.write().await.insert(signature_server.clone(), keys);
 },
 Err((signature_server, e)) => {
 warn!("Failed to fetch keys for {}: {:?}", signature_server, e);
@@ -1222,7 +1219,7 @@ impl Service {
 // Gets a list of servers for which we don't have the signing key yet. We go
 // over the PDUs and either cache the key or add it to the list that needs to be
 // retrieved.
-fn get_server_keys_from_cache(
+async fn get_server_keys_from_cache(
 &self, pdu: &RawJsonValue,
 servers: &mut BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>>,
 room_version: &RoomVersionId,
@@ -1239,7 +1236,7 @@ impl Service {
 );
 let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids");

-if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) {
+if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().await.get(event_id) {
 // Exponential backoff
 let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -1310,7 +1307,7 @@ impl Service {
 .await
 {
 debug!("Got signing keys: {:?}", keys);
-let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
+let mut pkm = pub_key_map.write().await;
 for k in keys.server_keys {
 let k = match k.deserialize() {
 Ok(key) => key,
@@ -1365,10 +1362,7 @@ impl Service {
 .into_iter()
 .map(|(k, v)| (k.to_string(), v.key))
 .collect();
-pub_key_map
-.write()
-.map_err(|_| Error::bad_database("RwLock is poisoned."))?
-.insert(origin.to_string(), result);
+pub_key_map.write().await.insert(origin.to_string(), result);
 }
 }
 debug!("Done handling Future result");
@@ -1384,15 +1378,15 @@ impl Service {
 let mut servers: BTreeMap<OwnedServerName, BTreeMap<OwnedServerSigningKeyId, QueryCriteria>> = BTreeMap::new();

 {
-let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?;
+let mut pkm = pub_key_map.write().await;

 // Try to fetch keys, failure is okay
 // Servers we couldn't find in the cache will be added to `servers`
 for pdu in &event.room_state.state {
-let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
+let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
 }
 for pdu in &event.room_state.auth_chain {
-let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm);
+let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm).await;
 }

 drop(pkm);
@@ -1491,18 +1485,13 @@ impl Service {
 ) -> Result<BTreeMap<String, Base64>> {
 let contains_all_ids = |keys: &BTreeMap<String, Base64>| signature_ids.iter().all(|id| keys.contains_key(id));

-let permit = services()
-.globals
-.servername_ratelimiter
-.read()
-.unwrap()
-.get(origin)
-.map(|s| Arc::clone(s).acquire_owned());
+let permit =
+services().globals.servername_ratelimiter.read().await.get(origin).map(|s| Arc::clone(s).acquire_owned());

 let permit = match permit {
 Some(p) => p,
 None => {
-let mut write = services().globals.servername_ratelimiter.write().unwrap();
+let mut write = services().globals.servername_ratelimiter.write().await;
 let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1))));

 s.acquire_owned()
@@ -1510,14 +1499,16 @@ impl Service {
 }
 .await;

-let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) {
-hash_map::Entry::Vacant(e) => {
-e.insert((Instant::now(), 1));
-},
-hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
+let back_off = |id| async {
+match services().globals.bad_signature_ratelimiter.write().await.entry(id) {
+hash_map::Entry::Vacant(e) => {
+e.insert((Instant::now(), 1));
+},
+hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
+}
 };

-if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) {
+if let Some((time, tries)) = services().globals.bad_signature_ratelimiter.read().await.get(&signature_ids) {
 // Exponential backoff
 let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
 if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
@@ -1591,7 +1582,7 @@ impl Service {

 drop(permit);

-back_off(signature_ids);
+back_off(signature_ids).await;

 warn!("Failed to find public key for server: {}", origin);
 Err(Error::BadServerResponse("Failed to find public key for server"))

@@ -1,21 +1,18 @@
 mod data;
-use std::{
-collections::{HashMap, HashSet},
-sync::Mutex,
-};
+use std::collections::{HashMap, HashSet};

 pub use data::Data;
 use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
+use tokio::sync::Mutex;

 use super::timeline::PduCount;
 use crate::Result;

+type LazyLoadWaitingMutex = Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>;
+
 pub struct Service {
 pub db: &'static dyn Data,

-#[allow(clippy::type_complexity)]
-pub lazy_load_waiting: Mutex<HashMap<(OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount), HashSet<OwnedUserId>>>,
+pub lazy_load_waiting: LazyLoadWaitingMutex,
 }

 impl Service {
@@ -27,21 +24,21 @@ impl Service {
 }

 #[tracing::instrument(skip(self))]
-pub fn lazy_load_mark_sent(
+pub async fn lazy_load_mark_sent(
 &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
 count: PduCount,
 ) {
 self.lazy_load_waiting
 .lock()
-.unwrap()
+.await
 .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load);
 }

 #[tracing::instrument(skip(self))]
-pub fn lazy_load_confirm_delivery(
+pub async fn lazy_load_confirm_delivery(
 &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
 ) -> Result<()> {
-if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&(
+if let Some(user_ids) = self.lazy_load_waiting.lock().await.remove(&(
 user_id.to_owned(),
 device_id.to_owned(),
 room_id.to_owned(),

@@ -1,4 +1,4 @@
-use std::sync::{Arc, Mutex};
+use std::sync::Arc;

 use lru_cache::LruCache;
 use ruma::{
@@ -25,6 +25,7 @@ use ruma::{
 space::SpaceRoomJoinRule,
 OwnedRoomId, RoomId, UserId,
 };
+use tokio::sync::Mutex;
 use tracing::{debug, error, warn};

 use crate::{services, Error, PduEvent, Result};
@@ -70,7 +71,7 @@ impl Service {
 break;
 }

-if let Some(cached) = self.roomid_spacechunk_cache.lock().unwrap().get_mut(&current_room.clone()).as_ref() {
+if let Some(cached) = self.roomid_spacechunk_cache.lock().await.get_mut(&current_room.clone()).as_ref() {
 if let Some(cached) = cached {
 let allowed = match &cached.join_rule {
 //CachedJoinRule::Simplified(s) => {
@@ -148,7 +149,7 @@ impl Service {
 .transpose()?
 .unwrap_or(JoinRule::Invite);

-self.roomid_spacechunk_cache.lock().unwrap().insert(
+self.roomid_spacechunk_cache.lock().await.insert(
 current_room.clone(),
 Some(CachedSpaceChunk {
 chunk,
@@ -223,7 +224,7 @@ impl Service {
 }
 }

-self.roomid_spacechunk_cache.lock().unwrap().insert(
+self.roomid_spacechunk_cache.lock().await.insert(
 current_room.clone(),
 Some(CachedSpaceChunk {
 chunk,
@@ -245,7 +246,7 @@ impl Service {
 }
 */
 } else {
-self.roomid_spacechunk_cache.lock().unwrap().insert(current_room.clone(), None);
+self.roomid_spacechunk_cache.lock().await.insert(current_room.clone(), None);
 }
 }
 }

@@ -74,7 +74,7 @@ impl Service {
 .await?;
 },
 TimelineEventType::SpaceChild => {
-services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id);
+services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id);
 },
 _ => continue,
 }

@@ -3,7 +3,7 @@ pub(crate) mod data;
 use std::{
 cmp::Ordering,
 collections::{BTreeMap, HashMap, HashSet},
-sync::{Arc, Mutex, RwLock},
+sync::Arc,
 };

 pub use data::Data;
@@ -30,7 +30,7 @@ use ruma::{
 };
 use serde::Deserialize;
 use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
-use tokio::sync::MutexGuard;
+use tokio::sync::{Mutex, MutexGuard, RwLock};
 use tracing::{error, info, warn};

 use super::state_compressor::CompressedStateEvent;
@@ -239,7 +239,7 @@ impl Service {
 services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?;

 let mutex_insert =
-Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(pdu.room_id.clone()).or_default());
+Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(pdu.room_id.clone()).or_default());
 let insert_lock = mutex_insert.lock().await;

 let count1 = services().globals.next_count()?;
@@ -398,7 +398,7 @@ impl Service {
 },
 TimelineEventType::SpaceChild => {
 if let Some(_state_key) = &pdu.state_key {
-services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id);
+services().rooms.spaces.roomid_spacechunk_cache.lock().await.remove(&pdu.room_id);
 }
 },
 TimelineEventType::RoomMember => {
@@ -997,7 +997,7 @@ impl Service {

 // Lock so we cannot backfill the same pdu twice at the same time
 let mutex =
-Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default());
+Arc::clone(services().globals.roomid_mutex_federation.write().await.entry(room_id.clone()).or_default());
 let mutex_lock = mutex.lock().await;

 // Skip the PDU if we already have it as a timeline event
@@ -1020,7 +1020,7 @@ impl Service {
 let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists");

 let mutex_insert =
-Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default());
+Arc::clone(services().globals.roomid_mutex_insert.write().await.entry(room_id.clone()).or_default());
 let insert_lock = mutex_insert.lock().await;

 let count = services().globals.next_count()?;

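A recurring shape in the event-handler hunks: the back_off helper used to be a synchronous closure that locked a std map, and a plain closure body cannot .await, so it becomes a closure returning an async block, and each call site gains an .await. A self-contained sketch of that conversion, with a OnceLock-backed map standing in for services().globals.bad_event_ratelimiter:

    use std::{
        collections::{hash_map, HashMap},
        sync::OnceLock,
        time::Instant,
    };

    use tokio::sync::RwLock;

    type Ratelimiter = RwLock<HashMap<String, (Instant, u32)>>;

    // Stand-in for `services().globals.bad_event_ratelimiter`: a process-wide
    // map of event id -> (time of last failure, number of tries).
    fn ratelimiter() -> &'static Ratelimiter {
        static MAP: OnceLock<Ratelimiter> = OnceLock::new();
        MAP.get_or_init(|| RwLock::new(HashMap::new()))
    }

    #[tokio::main]
    async fn main() {
        // Was: `let back_off = |id| match ratelimiter().write().unwrap().entry(id) { ... };`
        // The closure now returns an async block so the write lock can be awaited,
        // and every call site gains an `.await`.
        let back_off = |id: String| async {
            match ratelimiter().write().await.entry(id) {
                hash_map::Entry::Vacant(e) => {
                    e.insert((Instant::now(), 1));
                },
                hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
            }
        };

        back_off("$badevent:example.org".to_owned()).await;
        back_off("$badevent:example.org".to_owned()).await;

        let tries = ratelimiter().read().await.get("$badevent:example.org").map(|(_, n)| *n);
        assert_eq!(tries, Some(2));
    }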