diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index db102858..754c9840 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -12,6 +12,7 @@ use ruma::{ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; +use service::rooms::state_compressor::HashSetCompressStateEvent; use tracing_subscriber::EnvFilter; use crate::admin_command; @@ -632,7 +633,11 @@ pub(super) async fn force_set_room_state_from_server( .await?; info!("Forcing new room state"); - let (short_state_hash, new, removed) = self + let HashSetCompressStateEvent { + shortstatehash: short_state_hash, + added, + removed, + } = self .services .rooms .state_compressor @@ -643,7 +648,7 @@ pub(super) async fn force_set_room_state_from_server( self.services .rooms .state - .force_state(room_id.clone().as_ref(), short_state_hash, new, removed, &state_lock) + .force_state(room_id.clone().as_ref(), short_state_hash, added, removed, &state_lock) .await?; info!( diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 27de60c6..c41e93fa 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -39,7 +39,11 @@ use ruma::{ state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; -use service::{appservice::RegistrationInfo, rooms::state::RoomMutexGuard, Services}; +use service::{ + appservice::RegistrationInfo, + rooms::{state::RoomMutexGuard, state_compressor::HashSetCompressStateEvent}, + Services, +}; use crate::{client::full_user_deactivate, Ruma}; @@ -941,7 +945,11 @@ async fn join_room_by_id_helper_remote( .await; debug!("Saving compressed state"); - let (statehash_before_join, new, removed) = services + let HashSetCompressStateEvent { + shortstatehash: statehash_before_join, + added, + removed, + } = services .rooms .state_compressor .save_state(room_id, Arc::new(compressed)) @@ -951,7 +959,7 @@ async fn join_room_by_id_helper_remote( services .rooms .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) + .force_state(room_id, statehash_before_join, added, removed, &state_lock) .await?; info!("Updating joined counts for new room"); diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index ec04e748..adebd332 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -33,8 +33,11 @@ use ruma::{ RoomId, RoomVersionId, ServerName, UserId, }; -use super::state_compressor::CompressedStateEvent; -use crate::{globals, rooms, sending, server_keys, Dep}; +use crate::{ + globals, rooms, + rooms::state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + sending, server_keys, Dep, +}; pub struct Service { services: Services, @@ -692,7 +695,11 @@ impl Service { // Set the new room state to the resolved state debug!("Forcing new room state"); - let (sstatehash, new, removed) = self + let HashSetCompressStateEvent { + shortstatehash, + added, + removed, + } = self .services .state_compressor .save_state(room_id, new_room_state) @@ -700,7 +707,7 @@ impl Service { self.services .state - .force_state(room_id, sstatehash, new, removed, &state_lock) + .force_state(room_id, shortstatehash, added, removed, &state_lock) .await?; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 02c449cc..62011605 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -24,6 +24,7 @@ struct Services { globals: Dep, } +pub type ShortStateKey = ShortId; pub type ShortEventId = ShortId; pub type ShortRoomId = ShortId; pub type ShortId = u64; diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 6abaa198..34fab079 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -182,12 +182,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&state_ids_compressed) .copied() .collect(); @@ -259,7 +259,7 @@ impl Service { let replaces = states_parents .last() .map(|info| { - info.1 + info.full_state .iter() .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) }) @@ -421,7 +421,7 @@ impl Service { })? .pop() .expect("there is always one layer") - .1; + .full_state; let mut ret = HashMap::new(); for compressed in full_state.iter() { diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index adc26f00..f77a6d80 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -45,7 +45,7 @@ impl Data { .map_err(|e| err!(Database("Missing state IDs: {e}")))? .pop() .expect("there is always one layer") - .1; + .full_state; let mut result = HashMap::new(); let mut i: u8 = 0; @@ -78,7 +78,7 @@ impl Data { .await? .pop() .expect("there is always one layer") - .1; + .full_state; let mut result = HashMap::new(); let mut i: u8 = 0; @@ -123,7 +123,7 @@ impl Data { .map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))? .pop() .expect("there is always one layer") - .1; + .full_state; let compressed = full_state .iter() diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index be66c597..1f351f40 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -10,7 +10,7 @@ use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortId, Dep}; pub struct Service { pub stateinfo_cache: Mutex, @@ -27,24 +27,33 @@ struct Data { shortstatehash_statediff: Arc, } +#[derive(Clone)] struct StateDiff { parent: Option, added: Arc>, removed: Arc>, } +#[derive(Clone, Default)] +pub struct ShortStateInfo { + pub shortstatehash: ShortStateHash, + pub full_state: Arc>, + pub added: Arc>, + pub removed: Arc>, +} + +#[derive(Clone, Default)] +pub struct HashSetCompressStateEvent { + pub shortstatehash: ShortStateHash, + pub added: Arc>, + pub removed: Arc>, +} + +pub type ShortStateHash = ShortId; +pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; type ParentStatesVec = Vec; -type ShortStateInfo = ( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -); - -type HashSetCompressStateEvent = (u64, Arc>, Arc>); -pub type CompressedStateEvent = [u8; 2 * size_of::()]; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -95,14 +104,19 @@ impl Service { if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; - let mut state = (*response.last().expect("at least one response").1).clone(); + let mut state = (*response.last().expect("at least one response").full_state).clone(); state.extend(added.iter().copied()); let removed = (*removed).clone(); for r in &removed { state.remove(r); } - response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); + response.push(ShortStateInfo { + shortstatehash, + full_state: Arc::new(state), + added, + removed: Arc::new(removed), + }); self.stateinfo_cache .lock() @@ -111,7 +125,13 @@ impl Service { Ok(response) } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; + let response = vec![ShortStateInfo { + shortstatehash, + full_state: added.clone(), + added, + removed, + }]; + self.stateinfo_cache .lock() .expect("locked") @@ -185,8 +205,8 @@ impl Service { // To many layers, we have to go deeper let parent = parent_states.pop().expect("parent must have a state"); - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -236,14 +256,14 @@ impl Service { // 2. We replace a layer above let parent = parent_states.pop().expect("parent must have a state"); - let parent_2_len = parent.2.len(); - let parent_3_len = parent.3.len(); - let parent_diff = checked!(parent_2_len + parent_3_len)?; + let parent_added_len = parent.added.len(); + let parent_removed_len = parent.removed.len(); + let parent_diff = checked!(parent_added_len + parent_removed_len)?; if checked!(diffsum * diffsum)? >= checked!(2 * diff_to_sibling * parent_diff)? { // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.added).clone(); + let mut parent_removed = (*parent.removed).clone(); for removed in statediffremoved.iter() { if !parent_new.remove(removed) { @@ -275,7 +295,7 @@ impl Service { self.save_statediff( shortstatehash, &StateDiff { - parent: Some(parent.0), + parent: Some(parent.shortstatehash), added: statediffnew, removed: statediffremoved, }, @@ -311,7 +331,10 @@ impl Service { .await; if Some(new_shortstatehash) == previous_shortstatehash { - return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); + return Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + ..Default::default() + }); } let states_parents = if let Some(p) = previous_shortstatehash { @@ -322,12 +345,12 @@ impl Service { let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) + .difference(&parent_stateinfo.full_state) .copied() .collect(); let statediffremoved: HashSet<_> = parent_stateinfo - .1 + .full_state .difference(&new_state_ids_compressed) .copied() .collect(); @@ -347,7 +370,11 @@ impl Service { )?; }; - Ok((new_shortstatehash, statediffnew, statediffremoved)) + Ok(HashSetCompressStateEvent { + shortstatehash: new_shortstatehash, + added: statediffnew, + removed: statediffremoved, + }) } async fn get_statediff(&self, shortstatehash: u64) -> Result {