From 4add39d0fedcbe7946c6dfffac33d1e48111ea8b Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 31 Jan 2025 15:50:09 +0000 Subject: [PATCH] cache compressed state in a sorted structure for logarithmic queries with partial keys Signed-off-by: Jason Volk --- src/api/client/membership.rs | 9 +- .../rooms/event_handler/resolve_state.rs | 6 +- .../event_handler/upgrade_outlier_pdu.rs | 15 ++- src/service/rooms/state/mod.rs | 28 +++--- src/service/rooms/state_accessor/mod.rs | 99 ++++++++++++++----- src/service/rooms/state_compressor/mod.rs | 30 +++--- src/service/rooms/timeline/mod.rs | 4 +- 7 files changed, 118 insertions(+), 73 deletions(-) diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index d80aff0c..449d44d5 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -46,7 +46,10 @@ use ruma::{ use service::{ appservice::RegistrationInfo, pdu::gen_event_id, - rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, + rooms::{ + state::RoomMutexGuard, + state_compressor::{CompressedState, HashSetCompressStateEvent}, + }, Services, }; @@ -1169,7 +1172,7 @@ async fn join_room_by_id_helper_remote( } info!("Compressing state from send_join"); - let compressed: HashSet<_> = services + let compressed: CompressedState = services .rooms .state_compressor .compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) @@ -2340,7 +2343,7 @@ async fn knock_room_helper_remote( } info!("Compressing state from send_knock"); - let compressed: HashSet<_> = services + let compressed: CompressedState = services .rooms .state_compressor .compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index c3de5f2f..4d99b088 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -15,7 +15,7 @@ use ruma::{ OwnedEventId, RoomId, RoomVersionId, }; -use crate::rooms::state_compressor::CompressedStateEvent; +use crate::rooms::state_compressor::CompressedState; #[implement(super::Service)] #[tracing::instrument(name = "resolve", level = "debug", skip_all)] @@ -24,7 +24,7 @@ pub async fn resolve_state( room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap, -) -> Result>> { +) -> Result> { trace!("Loading current room state ids"); let current_sstatehash = self .services @@ -91,7 +91,7 @@ pub async fn resolve_state( .await; trace!("Compressing state..."); - let new_room_state: HashSet<_> = self + let new_room_state: CompressedState = self .services .state_compressor .compress_state_events( diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index 03697558..132daca7 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -1,10 +1,4 @@ -use std::{ - borrow::Borrow, - collections::{BTreeMap, HashSet}, - iter::once, - sync::Arc, - time::Instant, -}; +use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant}; use conduwuit::{ debug, debug_info, err, implement, trace, @@ -19,7 +13,10 @@ use ruma::{ }; use super::{get_room_version_id, to_room_version}; -use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPduId}; +use crate::rooms::{ + state_compressor::{CompressedState, HashSetCompressStateEvent}, + timeline::RawPduId, +}; #[implement(super::Service)] pub(super) async fn upgrade_outlier_to_timeline_pdu( @@ -173,7 +170,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( incoming_pdu.prev_events.len() ); - let state_ids_compressed: Arc> = self + let state_ids_compressed: Arc = self .services .state_compressor .compress_state_events( diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 1b0d0d58..de90a89c 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,9 +1,4 @@ -use std::{ - collections::{HashMap, HashSet}, - fmt::Write, - iter::once, - sync::Arc, -}; +use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc}; use conduwuit::{ err, @@ -33,7 +28,7 @@ use crate::{ globals, rooms, rooms::{ short::{ShortEventId, ShortStateHash}, - state_compressor::{parse_compressed_state_event, CompressedStateEvent}, + state_compressor::{parse_compressed_state_event, CompressedState}, }, Dep, }; @@ -102,10 +97,9 @@ impl Service { &self, room_id: &RoomId, shortstatehash: u64, - statediffnew: Arc>, - _statediffremoved: Arc>, - state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state - * mutex */ + statediffnew: Arc, + _statediffremoved: Arc, + state_lock: &RoomMutexGuard, ) -> Result { let event_ids = statediffnew .iter() @@ -176,7 +170,7 @@ impl Service { &self, event_id: &EventId, room_id: &RoomId, - state_ids_compressed: Arc>, + state_ids_compressed: Arc, ) -> Result { const KEY_LEN: usize = size_of::(); const VAL_LEN: usize = size_of::(); @@ -209,12 +203,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = state_ids_compressed + let statediffnew: CompressedState = state_ids_compressed .difference(&parent_stateinfo.full_state) .copied() .collect(); - let statediffremoved: HashSet<_> = parent_stateinfo + let statediffremoved: CompressedState = parent_stateinfo .full_state .difference(&state_ids_compressed) .copied() @@ -222,7 +216,7 @@ impl Service { (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { - (state_ids_compressed, Arc::new(HashSet::new())) + (state_ids_compressed, Arc::new(CompressedState::new())) }; self.services.state_compressor.save_state_from_diff( shortstatehash, @@ -300,10 +294,10 @@ impl Service { // TODO: statehash with deterministic inputs let shortstatehash = self.services.globals.next_count()?; - let mut statediffnew = HashSet::new(); + let mut statediffnew = CompressedState::new(); statediffnew.insert(new); - let mut statediffremoved = HashSet::new(); + let mut statediffremoved = CompressedState::new(); if let Some(replaces) = replaces { statediffremoved.insert(*replaces); } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 98aac138..8b56c8b6 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -11,6 +11,7 @@ use conduwuit::{ utils, utils::{ math::{usize_from_f64, Expected}, + result::FlatOk, stream::{BroadbandExt, IterStream, ReadyExt, TryExpect}, }, Err, Error, PduEvent, Result, @@ -47,7 +48,7 @@ use crate::{ rooms::{ short::{ShortEventId, ShortStateHash, ShortStateKey}, state::RoomMutexGuard, - state_compressor::parse_compressed_state_event, + state_compressor::{compress_state_event, parse_compressed_state_event}, }, Dep, }; @@ -220,36 +221,88 @@ impl Service { Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { - let shortstatekey = self - .services - .short - .get_shortstatekey(event_type, state_key) + let shorteventid = self + .state_get_shortid(shortstatehash, event_type, state_key) .await?; - let full_state = self - .services - .state_compressor - .load_shortstatehash_info(shortstatehash) - .await - .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? - .pop() - .expect("there is always one layer") - .full_state; - - let compressed = full_state - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .ok_or(err!(Database("No shortstatekey in compressed state")))?; - - let (_, shorteventid) = parse_compressed_state_event(*compressed); - self.services .short .get_eventid_from_short(shorteventid) .await } - #[inline] + /// Returns a single EventId from `room_id` with key (`event_type`, + /// `state_key`). + #[tracing::instrument(skip(self), level = "debug")] + pub async fn state_get_shortid( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, + ) -> Result { + let shortstatekey = self + .services + .short + .get_shortstatekey(event_type, state_key) + .await?; + + let start = compress_state_event(shortstatekey, 0); + let end = compress_state_event(shortstatekey, u64::MAX); + self.services + .state_compressor + .load_shortstatehash_info(shortstatehash) + .map_ok(|vec| vec.last().expect("at least one layer").full_state.clone()) + .map_ok(|full_state| { + full_state + .range(start..end) + .next() + .copied() + .map(parse_compressed_state_event) + .map(at!(1)) + .ok_or(err!(Request(NotFound("Not found in room state")))) + }) + .await? + } + + #[tracing::instrument(skip(self), level = "debug")] + pub async fn state_contains( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, + state_key: &str, + ) -> bool { + let Ok(shortstatekey) = self + .services + .short + .get_shortstatekey(event_type, state_key) + .await + else { + return false; + }; + + self.state_contains_shortstatekey(shortstatehash, shortstatekey) + .await + } + + #[tracing::instrument(skip(self), level = "debug")] + pub async fn state_contains_shortstatekey( + &self, + shortstatehash: ShortStateHash, + shortstatekey: ShortStateKey, + ) -> bool { + let start = compress_state_event(shortstatekey, 0); + let end = compress_state_event(shortstatekey, u64::MAX); + + self.services + .state_compressor + .load_shortstatehash_info(shortstatehash) + .map_ok(|vec| vec.last().expect("at least one layer").full_state.clone()) + .map_ok(|full_state| full_state.range(start..end).next().copied()) + .await + .flat_ok() + .is_some() + } + pub fn state_full_shortids( &self, shortstatehash: ShortStateHash, diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 532df360..3d68dff6 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,5 +1,5 @@ use std::{ - collections::{HashMap, HashSet}, + collections::{BTreeSet, HashMap}, fmt::{Debug, Write}, mem::size_of, sync::{Arc, Mutex}, @@ -63,8 +63,8 @@ type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; type ParentStatesVec = Vec; -pub(crate) type CompressedState = HashSet; -pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; +pub type CompressedState = BTreeSet; +pub type CompressedStateEvent = [u8; 2 * size_of::()]; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -249,8 +249,8 @@ impl Service { pub fn save_state_from_diff( &self, shortstatehash: ShortStateHash, - statediffnew: Arc>, - statediffremoved: Arc>, + statediffnew: Arc, + statediffremoved: Arc, diff_to_sibling: usize, mut parent_states: ParentStatesVec, ) -> Result { @@ -363,7 +363,7 @@ impl Service { pub async fn save_state( &self, room_id: &RoomId, - new_state_ids_compressed: Arc>, + new_state_ids_compressed: Arc, ) -> Result { let previous_shortstatehash = self .services @@ -396,12 +396,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = new_state_ids_compressed + let statediffnew: CompressedState = new_state_ids_compressed .difference(&parent_stateinfo.full_state) .copied() .collect(); - let statediffremoved: HashSet<_> = parent_stateinfo + let statediffremoved: CompressedState = parent_stateinfo .full_state .difference(&new_state_ids_compressed) .copied() @@ -409,7 +409,7 @@ impl Service { (Arc::new(statediffnew), Arc::new(statediffremoved)) } else { - (new_state_ids_compressed, Arc::new(HashSet::new())) + (new_state_ids_compressed, Arc::new(CompressedState::new())) }; if !already_existed { @@ -448,11 +448,11 @@ impl Service { .take_if(|parent| *parent != 0); debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); - let num_values = value.len() / STRIDE; + let _num_values = value.len() / STRIDE; let mut add_mode = true; - let mut added = HashSet::with_capacity(num_values); - let mut removed = HashSet::with_capacity(num_values); + let mut added = CompressedState::new(); + let mut removed = CompressedState::new(); let mut i = STRIDE; while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { @@ -469,8 +469,6 @@ impl Service { i = expected!(i + 2 * STRIDE); } - added.shrink_to_fit(); - removed.shrink_to_fit(); Ok(StateDiff { parent, added: Arc::new(added), @@ -507,7 +505,7 @@ impl Service { #[inline] #[must_use] -fn compress_state_event( +pub(crate) fn compress_state_event( shortstatekey: ShortStateKey, shorteventid: ShortEventId, ) -> CompressedStateEvent { @@ -523,7 +521,7 @@ fn compress_state_event( #[inline] #[must_use] -pub fn parse_compressed_state_event( +pub(crate) fn parse_compressed_state_event( compressed_event: CompressedStateEvent, ) -> (ShortStateKey, ShortEventId) { use utils::u64_from_u8; diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 8b3b67a7..a913034d 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -49,7 +49,7 @@ use crate::{ account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent}, + rooms::{short::ShortRoomId, state_compressor::CompressedState}, sending, server_keys, users, Dep, }; @@ -950,7 +950,7 @@ impl Service { pdu: &'a PduEvent, pdu_json: CanonicalJsonObject, new_room_leafs: Leafs, - state_ids_compressed: Arc>, + state_ids_compressed: Arc, soft_fail: bool, state_lock: &'a RoomMutexGuard, ) -> Result>