cache compressed state in a sorted structure for logarithmic queries with partial keys

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-31 15:50:09 +00:00
parent ea49b60273
commit 4add39d0fe
7 changed files with 118 additions and 73 deletions

View file

@ -46,7 +46,10 @@ use ruma::{
use service::{ use service::{
appservice::RegistrationInfo, appservice::RegistrationInfo,
pdu::gen_event_id, pdu::gen_event_id,
rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, rooms::{
state::RoomMutexGuard,
state_compressor::{CompressedState, HashSetCompressStateEvent},
},
Services, Services,
}; };
@ -1169,7 +1172,7 @@ async fn join_room_by_id_helper_remote(
} }
info!("Compressing state from send_join"); info!("Compressing state from send_join");
let compressed: HashSet<_> = services let compressed: CompressedState = services
.rooms .rooms
.state_compressor .state_compressor
.compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) .compress_state_events(state.iter().map(|(ssk, eid)| (ssk, eid.borrow())))
@ -2340,7 +2343,7 @@ async fn knock_room_helper_remote(
} }
info!("Compressing state from send_knock"); info!("Compressing state from send_knock");
let compressed: HashSet<_> = services let compressed: CompressedState = services
.rooms .rooms
.state_compressor .state_compressor
.compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow()))) .compress_state_events(state_map.iter().map(|(ssk, eid)| (ssk, eid.borrow())))

View file

@ -15,7 +15,7 @@ use ruma::{
OwnedEventId, RoomId, RoomVersionId, OwnedEventId, RoomId, RoomVersionId,
}; };
use crate::rooms::state_compressor::CompressedStateEvent; use crate::rooms::state_compressor::CompressedState;
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(name = "resolve", level = "debug", skip_all)] #[tracing::instrument(name = "resolve", level = "debug", skip_all)]
@ -24,7 +24,7 @@ pub async fn resolve_state(
room_id: &RoomId, room_id: &RoomId,
room_version_id: &RoomVersionId, room_version_id: &RoomVersionId,
incoming_state: HashMap<u64, OwnedEventId>, incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> { ) -> Result<Arc<CompressedState>> {
trace!("Loading current room state ids"); trace!("Loading current room state ids");
let current_sstatehash = self let current_sstatehash = self
.services .services
@ -91,7 +91,7 @@ pub async fn resolve_state(
.await; .await;
trace!("Compressing state..."); trace!("Compressing state...");
let new_room_state: HashSet<_> = self let new_room_state: CompressedState = self
.services .services
.state_compressor .state_compressor
.compress_state_events( .compress_state_events(

View file

@ -1,10 +1,4 @@
use std::{ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant};
borrow::Borrow,
collections::{BTreeMap, HashSet},
iter::once,
sync::Arc,
time::Instant,
};
use conduwuit::{ use conduwuit::{
debug, debug_info, err, implement, trace, debug, debug_info, err, implement, trace,
@ -19,7 +13,10 @@ use ruma::{
}; };
use super::{get_room_version_id, to_room_version}; use super::{get_room_version_id, to_room_version};
use crate::rooms::{state_compressor::HashSetCompressStateEvent, timeline::RawPduId}; use crate::rooms::{
state_compressor::{CompressedState, HashSetCompressStateEvent},
timeline::RawPduId,
};
#[implement(super::Service)] #[implement(super::Service)]
pub(super) async fn upgrade_outlier_to_timeline_pdu( pub(super) async fn upgrade_outlier_to_timeline_pdu(
@ -173,7 +170,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
incoming_pdu.prev_events.len() incoming_pdu.prev_events.len()
); );
let state_ids_compressed: Arc<HashSet<_>> = self let state_ids_compressed: Arc<CompressedState> = self
.services .services
.state_compressor .state_compressor
.compress_state_events( .compress_state_events(

View file

@ -1,9 +1,4 @@
use std::{ use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc};
collections::{HashMap, HashSet},
fmt::Write,
iter::once,
sync::Arc,
};
use conduwuit::{ use conduwuit::{
err, err,
@ -33,7 +28,7 @@ use crate::{
globals, rooms, globals, rooms,
rooms::{ rooms::{
short::{ShortEventId, ShortStateHash}, short::{ShortEventId, ShortStateHash},
state_compressor::{parse_compressed_state_event, CompressedStateEvent}, state_compressor::{parse_compressed_state_event, CompressedState},
}, },
Dep, Dep,
}; };
@ -102,10 +97,9 @@ impl Service {
&self, &self,
room_id: &RoomId, room_id: &RoomId,
shortstatehash: u64, shortstatehash: u64,
statediffnew: Arc<HashSet<CompressedStateEvent>>, statediffnew: Arc<CompressedState>,
_statediffremoved: Arc<HashSet<CompressedStateEvent>>, _statediffremoved: Arc<CompressedState>,
state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state state_lock: &RoomMutexGuard,
* mutex */
) -> Result { ) -> Result {
let event_ids = statediffnew let event_ids = statediffnew
.iter() .iter()
@ -176,7 +170,7 @@ impl Service {
&self, &self,
event_id: &EventId, event_id: &EventId,
room_id: &RoomId, room_id: &RoomId,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, state_ids_compressed: Arc<CompressedState>,
) -> Result<ShortStateHash> { ) -> Result<ShortStateHash> {
const KEY_LEN: usize = size_of::<ShortEventId>(); const KEY_LEN: usize = size_of::<ShortEventId>();
const VAL_LEN: usize = size_of::<ShortStateHash>(); const VAL_LEN: usize = size_of::<ShortStateHash>();
@ -209,12 +203,12 @@ impl Service {
let (statediffnew, statediffremoved) = let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() { if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed let statediffnew: CompressedState = state_ids_compressed
.difference(&parent_stateinfo.full_state) .difference(&parent_stateinfo.full_state)
.copied() .copied()
.collect(); .collect();
let statediffremoved: HashSet<_> = parent_stateinfo let statediffremoved: CompressedState = parent_stateinfo
.full_state .full_state
.difference(&state_ids_compressed) .difference(&state_ids_compressed)
.copied() .copied()
@ -222,7 +216,7 @@ impl Service {
(Arc::new(statediffnew), Arc::new(statediffremoved)) (Arc::new(statediffnew), Arc::new(statediffremoved))
} else { } else {
(state_ids_compressed, Arc::new(HashSet::new())) (state_ids_compressed, Arc::new(CompressedState::new()))
}; };
self.services.state_compressor.save_state_from_diff( self.services.state_compressor.save_state_from_diff(
shortstatehash, shortstatehash,
@ -300,10 +294,10 @@ impl Service {
// TODO: statehash with deterministic inputs // TODO: statehash with deterministic inputs
let shortstatehash = self.services.globals.next_count()?; let shortstatehash = self.services.globals.next_count()?;
let mut statediffnew = HashSet::new(); let mut statediffnew = CompressedState::new();
statediffnew.insert(new); statediffnew.insert(new);
let mut statediffremoved = HashSet::new(); let mut statediffremoved = CompressedState::new();
if let Some(replaces) = replaces { if let Some(replaces) = replaces {
statediffremoved.insert(*replaces); statediffremoved.insert(*replaces);
} }

View file

@ -11,6 +11,7 @@ use conduwuit::{
utils, utils,
utils::{ utils::{
math::{usize_from_f64, Expected}, math::{usize_from_f64, Expected},
result::FlatOk,
stream::{BroadbandExt, IterStream, ReadyExt, TryExpect}, stream::{BroadbandExt, IterStream, ReadyExt, TryExpect},
}, },
Err, Error, PduEvent, Result, Err, Error, PduEvent, Result,
@ -47,7 +48,7 @@ use crate::{
rooms::{ rooms::{
short::{ShortEventId, ShortStateHash, ShortStateKey}, short::{ShortEventId, ShortStateHash, ShortStateKey},
state::RoomMutexGuard, state::RoomMutexGuard,
state_compressor::parse_compressed_state_event, state_compressor::{compress_state_event, parse_compressed_state_event},
}, },
Dep, Dep,
}; };
@ -220,36 +221,88 @@ impl Service {
Id: for<'de> Deserialize<'de> + Sized + ToOwned, Id: for<'de> Deserialize<'de> + Sized + ToOwned,
<Id as ToOwned>::Owned: Borrow<EventId>, <Id as ToOwned>::Owned: Borrow<EventId>,
{ {
let shortstatekey = self let shorteventid = self
.services .state_get_shortid(shortstatehash, event_type, state_key)
.short
.get_shortstatekey(event_type, state_key)
.await?; .await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.full_state;
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
let (_, shorteventid) = parse_compressed_state_event(*compressed);
self.services self.services
.short .short
.get_eventid_from_short(shorteventid) .get_eventid_from_short(shorteventid)
.await .await
} }
#[inline] /// Returns the `ShortEventId` recorded in the state `shortstatehash` under key
/// (`event_type`, `state_key`).
///
/// The lookup is a range query over the sorted compressed state rather than a
/// linear scan: the bounds are the state key paired with short event id `0`
/// (all-zero bytes) and with `u64::MAX` (all-0xFF bytes), so the first entry
/// inside that range, if any, is the one stored for the key. This relies on
/// `compress_state_event` placing the state key bytes ahead of the event id
/// bytes.
///
/// # Errors
///
/// - Propagates the failure of `get_shortstatekey` when the (`event_type`,
///   `state_key`) pair cannot be resolved to a short state key.
/// - Propagates the failure of `load_shortstatehash_info`.
/// - Returns `Request(NotFound)` when the state holds no entry for the key.
///
/// # Panics
///
/// Panics if the loaded state info contains no layers.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_shortid(
	&self,
	shortstatehash: ShortStateHash,
	event_type: &StateEventType,
	state_key: &str,
) -> Result<ShortEventId> {
	let shortstatekey = self
		.services
		.short
		.get_shortstatekey(event_type, state_key)
		.await?;

	// Smallest and largest possible compressed entries for this state key.
	let start = compress_state_event(shortstatekey, 0);
	let end = compress_state_event(shortstatekey, u64::MAX);

	self.services
		.state_compressor
		.load_shortstatehash_info(shortstatehash)
		.map_ok(|vec| vec.last().expect("at least one layer").full_state.clone())
		.map_ok(|full_state| {
			full_state
				// Inclusive upper bound: `end` itself is a valid entry and must
				// not be excluded from the search.
				.range(start..=end)
				.next()
				.copied()
				.map(parse_compressed_state_event)
				.map(at!(1))
				// Build the error lazily; the success path should not pay for it.
				.ok_or_else(|| err!(Request(NotFound("Not found in room state"))))
		})
		.await?
}
/// Reports whether the state `shortstatehash` has an entry for key
/// (`event_type`, `state_key`).
///
/// Returns `false` when the key cannot be resolved to a short state key;
/// otherwise defers to `state_contains_shortstatekey`.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains(
	&self,
	shortstatehash: ShortStateHash,
	event_type: &StateEventType,
	state_key: &str,
) -> bool {
	// Discard the lookup error up front: a failed lookup simply means "absent".
	let resolved = self
		.services
		.short
		.get_shortstatekey(event_type, state_key)
		.await
		.ok();

	match resolved {
		| Some(shortstatekey) =>
			self.state_contains_shortstatekey(shortstatehash, shortstatekey)
				.await,
		| None => false,
	}
}
/// Reports whether the state `shortstatehash` has an entry for
/// `shortstatekey`.
///
/// Performs a range query over the sorted compressed state, bounded by the
/// key paired with short event id `0` and with `u64::MAX`. Any failure to
/// load the state info is treated as "not contained" and yields `false`.
///
/// # Panics
///
/// Panics if the loaded state info contains no layers.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_contains_shortstatekey(
	&self,
	shortstatehash: ShortStateHash,
	shortstatekey: ShortStateKey,
) -> bool {
	// Smallest and largest possible compressed entries for this state key.
	let start = compress_state_event(shortstatekey, 0);
	let end = compress_state_event(shortstatekey, u64::MAX);

	self.services
		.state_compressor
		.load_shortstatehash_info(shortstatehash)
		.map_ok(|vec| vec.last().expect("at least one layer").full_state.clone())
		// Inclusive upper bound: `end` itself is a valid entry and must not be
		// excluded from the search.
		.map_ok(|full_state| full_state.range(start..=end).next().copied())
		.await
		// Collapse `Result<Option<_>>` into `Option<_>`; errors become `None`.
		.flat_ok()
		.is_some()
}
pub fn state_full_shortids( pub fn state_full_shortids(
&self, &self,
shortstatehash: ShortStateHash, shortstatehash: ShortStateHash,

View file

@ -1,5 +1,5 @@
use std::{ use std::{
collections::{HashMap, HashSet}, collections::{BTreeSet, HashMap},
fmt::{Debug, Write}, fmt::{Debug, Write},
mem::size_of, mem::size_of,
sync::{Arc, Mutex}, sync::{Arc, Mutex},
@ -63,8 +63,8 @@ type StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>;
type ShortStateInfoVec = Vec<ShortStateInfo>; type ShortStateInfoVec = Vec<ShortStateInfo>;
type ParentStatesVec = Vec<ShortStateInfo>; type ParentStatesVec = Vec<ShortStateInfo>;
pub(crate) type CompressedState = HashSet<CompressedStateEvent>; pub type CompressedState = BTreeSet<CompressedStateEvent>;
pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()]; pub type CompressedStateEvent = [u8; 2 * size_of::<ShortId>()];
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -249,8 +249,8 @@ impl Service {
pub fn save_state_from_diff( pub fn save_state_from_diff(
&self, &self,
shortstatehash: ShortStateHash, shortstatehash: ShortStateHash,
statediffnew: Arc<HashSet<CompressedStateEvent>>, statediffnew: Arc<CompressedState>,
statediffremoved: Arc<HashSet<CompressedStateEvent>>, statediffremoved: Arc<CompressedState>,
diff_to_sibling: usize, diff_to_sibling: usize,
mut parent_states: ParentStatesVec, mut parent_states: ParentStatesVec,
) -> Result { ) -> Result {
@ -363,7 +363,7 @@ impl Service {
pub async fn save_state( pub async fn save_state(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, new_state_ids_compressed: Arc<CompressedState>,
) -> Result<HashSetCompressStateEvent> { ) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self let previous_shortstatehash = self
.services .services
@ -396,12 +396,12 @@ impl Service {
let (statediffnew, statediffremoved) = let (statediffnew, statediffremoved) =
if let Some(parent_stateinfo) = states_parents.last() { if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed let statediffnew: CompressedState = new_state_ids_compressed
.difference(&parent_stateinfo.full_state) .difference(&parent_stateinfo.full_state)
.copied() .copied()
.collect(); .collect();
let statediffremoved: HashSet<_> = parent_stateinfo let statediffremoved: CompressedState = parent_stateinfo
.full_state .full_state
.difference(&new_state_ids_compressed) .difference(&new_state_ids_compressed)
.copied() .copied()
@ -409,7 +409,7 @@ impl Service {
(Arc::new(statediffnew), Arc::new(statediffremoved)) (Arc::new(statediffnew), Arc::new(statediffremoved))
} else { } else {
(new_state_ids_compressed, Arc::new(HashSet::new())) (new_state_ids_compressed, Arc::new(CompressedState::new()))
}; };
if !already_existed { if !already_existed {
@ -448,11 +448,11 @@ impl Service {
.take_if(|parent| *parent != 0); .take_if(|parent| *parent != 0);
debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride"); debug_assert!(value.len() % STRIDE == 0, "value not aligned to stride");
let num_values = value.len() / STRIDE; let _num_values = value.len() / STRIDE;
let mut add_mode = true; let mut add_mode = true;
let mut added = HashSet::with_capacity(num_values); let mut added = CompressedState::new();
let mut removed = HashSet::with_capacity(num_values); let mut removed = CompressedState::new();
let mut i = STRIDE; let mut i = STRIDE;
while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) {
@ -469,8 +469,6 @@ impl Service {
i = expected!(i + 2 * STRIDE); i = expected!(i + 2 * STRIDE);
} }
added.shrink_to_fit();
removed.shrink_to_fit();
Ok(StateDiff { Ok(StateDiff {
parent, parent,
added: Arc::new(added), added: Arc::new(added),
@ -507,7 +505,7 @@ impl Service {
#[inline] #[inline]
#[must_use] #[must_use]
fn compress_state_event( pub(crate) fn compress_state_event(
shortstatekey: ShortStateKey, shortstatekey: ShortStateKey,
shorteventid: ShortEventId, shorteventid: ShortEventId,
) -> CompressedStateEvent { ) -> CompressedStateEvent {
@ -523,7 +521,7 @@ fn compress_state_event(
#[inline] #[inline]
#[must_use] #[must_use]
pub fn parse_compressed_state_event( pub(crate) fn parse_compressed_state_event(
compressed_event: CompressedStateEvent, compressed_event: CompressedStateEvent,
) -> (ShortStateKey, ShortEventId) { ) -> (ShortStateKey, ShortEventId) {
use utils::u64_from_u8; use utils::u64_from_u8;

View file

@ -49,7 +49,7 @@ use crate::{
account_data, admin, appservice, account_data, admin, appservice,
appservice::NamespaceRegex, appservice::NamespaceRegex,
globals, pusher, rooms, globals, pusher, rooms,
rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent}, rooms::{short::ShortRoomId, state_compressor::CompressedState},
sending, server_keys, users, Dep, sending, server_keys, users, Dep,
}; };
@ -950,7 +950,7 @@ impl Service {
pdu: &'a PduEvent, pdu: &'a PduEvent,
pdu_json: CanonicalJsonObject, pdu_json: CanonicalJsonObject,
new_room_leafs: Leafs, new_room_leafs: Leafs,
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>, state_ids_compressed: Arc<CompressedState>,
soft_fail: bool, soft_fail: bool,
state_lock: &'a RoomMutexGuard, state_lock: &'a RoomMutexGuard,
) -> Result<Option<RawPduId>> ) -> Result<Option<RawPduId>>