diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 65c9bc71..350e08c6 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -27,33 +27,32 @@ pub(super) async fn echo(&self, message: Vec) -> Result) -> Result { - let event_id = Arc::::from(event_id); - if let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await { - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + let Ok(event) = self.services.rooms.timeline.get_pdu_json(&event_id).await else { + return Ok(RoomMessageEventContent::notice_plain("Event not found.")); + }; - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let start = Instant::now(); - let count = self - .services - .rooms - .auth_chain - .event_ids_iter(room_id, vec![event_id]) - .await? - .count() - .await; + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let elapsed = start.elapsed(); - Ok(RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {count} in {elapsed:?}" - ))) - } else { - Ok(RoomMessageEventContent::text_plain("Event not found.")) - } + let start = Instant::now(); + let count = self + .services + .rooms + .auth_chain + .event_ids_iter(room_id, &[&event_id]) + .await? + .count() + .await; + + let elapsed = start.elapsed(); + Ok(RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + ))) } #[admin_command] diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 6ec00b50..8307a4ad 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{Error, Result}; @@ -57,7 +57,7 @@ pub(crate) async fn get_event_authorization_route( let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(room_id, &[body.event_id.borrow()]) .await? .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 639fcafd..f9257690 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,6 +1,6 @@ #![allow(deprecated)] -use std::collections::BTreeMap; +use std::{borrow::Borrow, collections::BTreeMap}; use axum::extract::State; use conduit::{err, pdu::gen_event_id_canonical_json, utils::IterStream, warn, Error, Result}; @@ -11,7 +11,7 @@ use ruma::{ room::member::{MembershipState, RoomMemberEventContent}, StateEventType, }, - CanonicalJsonValue, OwnedServerName, OwnedUserId, RoomId, ServerName, + CanonicalJsonValue, EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use service::Services; @@ -196,10 +196,11 @@ async fn create_join_event( .try_collect() .await?; + let starting_events: Vec<&EventId> = state_ids.values().map(Borrow::borrow).collect(); let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, state_ids.values().cloned().collect()) + .event_ids_iter(room_id, &starting_events) .await? .map(Ok) .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 37a14a3f..3a27cd0a 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{err, result::LogErr, utils::IterStream, Err, Result}; @@ -63,7 +63,7 @@ pub(crate) async fn get_room_state_route( let auth_chain = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) .await? .map(Ok) .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 95ca65aa..b026abf1 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,4 +1,4 @@ -use std::sync::Arc; +use std::borrow::Borrow; use axum::extract::State; use conduit::{err, Err}; @@ -55,7 +55,7 @@ pub(crate) async fn get_room_state_ids_route( let auth_chain_ids = services .rooms .auth_chain - .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) + .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) .await? .map(|id| (*id).to_owned()) .collect() diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 3d00374e..5c9dbda8 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -3,7 +3,7 @@ use std::{ sync::{Arc, Mutex}, }; -use conduit::{utils, utils::math::usize_from_f64, Result}; +use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; @@ -24,54 +24,63 @@ impl Data { } } - pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); + if let Some(result) = self + .auth_chain_cache + .lock() + .expect("cache locked") + .get_mut(key) + { + return Ok(Arc::clone(result)); } // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| { - chain - .chunks_exact(size_of::()) - .map(utils::u64_from_u8) - .collect::>() - }); - - if let Ok(chain) = chain { - // Cache in RAM - self.auth_chain_cache - .lock() - .expect("locked") - .insert(vec![key[0]], Arc::clone(&chain)); - - return Ok(Some(chain)); - } + if key.len() != 1 { + return Err!(Request(NotFound("auth_chain not cached"))); } - Ok(None) + // Check database + let chain = self + .shorteventid_authchain + .qry(&key[0]) + .await + .map_err(|_| err!(Request(NotFound("auth_chain not found"))))?; + + let chain = chain + .chunks_exact(size_of::()) + .map(utils::u64_from_u8) + .collect::>(); + + // Cache in RAM + self.auth_chain_cache + .lock() + .expect("cache locked") + .insert(vec![key[0]], Arc::clone(&chain)); + + Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) { + debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); + // Only persist single events in db if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - ); + let key = key[0].to_be_bytes(); + let val = auth_chain + .iter() + .flat_map(|s| s.to_be_bytes().to_vec()) + .collect::>(); + + self.shorteventid_authchain.insert(&key, &val); } // Cache in RAM self.auth_chain_cache .lock() - .expect("locked") + .expect("cache locked") .insert(key, auth_chain); - - Ok(()) } } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 7bc239d7..eae13b74 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -37,25 +37,18 @@ impl crate::Service for Service { } impl Service { - pub async fn event_ids_iter<'a>( - &'a self, room_id: &RoomId, starting_events_: Vec>, - ) -> Result> + Send + 'a> { - let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len()); - for starting_event in &starting_events_ { - starting_events.push(starting_event); - } + pub async fn event_ids_iter( + &self, room_id: &RoomId, starting_events: &[&EventId], + ) -> Result> + Send + '_> { + let chain = self.get_auth_chain(room_id, starting_events).await?; + let iter = chain.into_iter().stream().filter_map(|sid| { + self.services + .short + .get_eventid_from_short(sid) + .map(Result::ok) + }); - Ok(self - .get_auth_chain(room_id, &starting_events) - .await? - .into_iter() - .stream() - .filter_map(|sid| { - self.services - .short - .get_eventid_from_short(sid) - .map(Result::ok) - })) + Ok(iter) } #[tracing::instrument(skip_all, name = "auth_chain")] @@ -93,7 +86,7 @@ impl Service { } let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? { + if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); hits = hits.saturating_add(1); @@ -104,13 +97,13 @@ impl Service { let mut misses2: usize = 0; let mut chunk_cache = Vec::with_capacity(chunk.len()); for (sevent_id, event_id) in chunk { - if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? { + if let Ok(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await { trace!(?event_id, "Found cache entry for event"); chunk_cache.extend(cached.iter().copied()); hits2 = hits2.saturating_add(1); } else { let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; - self.cache_auth_chain(vec![sevent_id], &auth_chain)?; + self.cache_auth_chain(vec![sevent_id], &auth_chain); chunk_cache.extend(auth_chain.iter()); misses2 = misses2.saturating_add(1); debug!( @@ -125,7 +118,7 @@ impl Service { chunk_cache.sort_unstable(); chunk_cache.dedup(); - self.cache_auth_chain_vec(chunk_key, &chunk_cache)?; + self.cache_auth_chain_vec(chunk_key, &chunk_cache); full_auth_chain.extend(chunk_cache.iter()); misses = misses.saturating_add(1); debug!( @@ -163,11 +156,11 @@ impl Service { Ok(pdu) => { if pdu.room_id != room_id { return Err!(Request(Forbidden( - "auth event {event_id:?} for incorrect room {} which is not {}", + "auth event {event_id:?} for incorrect room {} which is not {room_id}", pdu.room_id, - room_id ))); } + for auth_event in &pdu.auth_events { let sauthevent = self .services @@ -187,20 +180,21 @@ impl Service { Ok(found) } - pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + #[inline] + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) -> Result<()> { - self.db - .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { + let val = auth_chain.iter().copied().collect::>(); + self.db.cache_auth_chain(key, val); } pub fn get_cache_usage(&self) -> (usize, usize) { diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 07d6e4db..57b87706 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,6 +1,7 @@ mod parse_incoming_pdu; use std::{ + borrow::Borrow, collections::{hash_map, BTreeMap, HashMap, HashSet}, fmt::Write, sync::{Arc, RwLock as StdRwLock}, @@ -773,6 +774,7 @@ impl Service { Ok(pdu_id) } + #[tracing::instrument(skip_all, name = "resolve")] pub async fn resolve_state( &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, ) -> Result>> { @@ -793,14 +795,17 @@ impl Service { let fork_states = [current_state_ids, incoming_state]; let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, state.iter().map(|(_, id)| id.clone()).collect()) - .await? - .collect::>>() - .await, - ); + let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + + let auth_chain = self + .services + .auth_chain + .event_ids_iter(room_id, &starting_events) + .await? + .collect::>>() + .await; + + auth_chain_sets.push(auth_chain); } debug!("Loading fork states"); @@ -962,12 +967,11 @@ impl Service { let mut state = StateMap::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { + for (k, id) in &leaf_state { if let Ok((ty, st_key)) = self .services .short - .get_statekey_from_short(k) + .get_statekey_from_short(*k) .await .log_err() { @@ -976,18 +980,18 @@ impl Service { state.insert((ty.to_string().into(), st_key), id.clone()); } - starting_events.push(id); + starting_events.push(id.borrow()); } - auth_chain_sets.push( - self.services - .auth_chain - .event_ids_iter(room_id, starting_events) - .await? - .collect() - .await, - ); + let auth_chain = self + .services + .auth_chain + .event_ids_iter(room_id, &starting_events) + .await? + .collect() + .await; + auth_chain_sets.push(auth_chain); fork_states.push(state); }