diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs deleted file mode 100644 index cb020470..00000000 --- a/src/service/rooms/state_compressor/data.rs +++ /dev/null @@ -1,84 +0,0 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; - -use conduit::{err, expected, utils, Result}; -use database::{Database, Map}; - -use super::CompressedStateEvent; - -pub(super) struct StateDiff { - pub(super) parent: Option, - pub(super) added: Arc>, - pub(super) removed: Arc>, -} - -pub(super) struct Data { - shortstatehash_statediff: Arc, -} - -impl Data { - pub(super) fn new(db: &Arc) -> Self { - Self { - shortstatehash_statediff: db["shortstatehash_statediff"].clone(), - } - } - - pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result { - const BUFSIZE: usize = size_of::(); - - let value = self - .shortstatehash_statediff - .aqry::(&shortstatehash) - .await - .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; - - let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); - let parent = if parent != 0 { - Some(parent) - } else { - None - }; - - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); - - let stride = size_of::(); - let mut i = stride; - while let Some(v) = value.get(i..expected!(i + 2 * stride)) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i = expected!(i + stride); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i = expected!(i + 2 * stride); - } - - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) - } - - pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } - - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } - } - - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value); - } -} diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index cd3f2f73..be66c597 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,53 +1,21 @@ -mod data; - use std::{ collections::HashSet, fmt::Write, mem::size_of, - sync::{Arc, Mutex as StdMutex, Mutex}, + sync::{Arc, Mutex}, }; -use conduit::{checked, utils, utils::math::usize_from_f64, Result}; -use data::Data; +use conduit::{checked, err, expected, utils, utils::math::usize_from_f64, Result}; +use database::Map; use lru_cache::LruCache; use ruma::{EventId, RoomId}; -use self::data::StateDiff; use crate::{rooms, Dep}; -type StateInfoLruCache = Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, - >, ->; - -type ShortStateInfoResult = Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -)>; - -type ParentStatesVec = Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed -)>; - -type HashSetCompressStateEvent = (u64, Arc>, Arc>); -pub type CompressedStateEvent = [u8; 2 * size_of::()]; - pub struct Service { + pub stateinfo_cache: Mutex, db: Data, services: Services, - pub stateinfo_cache: StateInfoLruCache, } struct Services { @@ -55,17 +23,42 @@ struct Services { state: Dep, } +struct Data { + shortstatehash_statediff: Arc, +} + +struct StateDiff { + parent: Option, + added: Arc>, + removed: Arc>, +} + +type StateInfoLruCache = LruCache; +type ShortStateInfoVec = Vec; +type ParentStatesVec = Vec; +type ShortStateInfo = ( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed +); + +type HashSetCompressStateEvent = (u64, Arc>, Arc>); +pub type CompressedStateEvent = [u8; 2 * size_of::()]; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; let cache_capacity = f64::from(config.stateinfo_cache_capacity) * config.cache_capacity_modifier; Ok(Arc::new(Self { - db: Data::new(args.db), + stateinfo_cache: LruCache::new(usize_from_f64(cache_capacity)?).into(), + db: Data { + shortstatehash_statediff: args.db["shortstatehash_statediff"].clone(), + }, services: Services { short: args.depend::("rooms::short"), state: args.depend::("rooms::state"), }, - stateinfo_cache: StdMutex::new(LruCache::new(usize_from_f64(cache_capacity)?)), })) } @@ -84,7 +77,7 @@ impl crate::Service for Service { impl Service { /// Returns a stack with info on shortstatehash, full state, added diff and /// removed diff for the selected shortstatehash and each parent layer. - pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { + pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result { if let Some(r) = self .stateinfo_cache .lock() @@ -98,7 +91,7 @@ impl Service { parent, added, removed, - } = self.db.get_statediff(shortstatehash).await?; + } = self.get_statediff(shortstatehash).await?; if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; @@ -177,12 +170,12 @@ impl Service { /// for this layer /// * `parent_states` - A stack with info on shortstatehash, full state, /// added diff and removed diff for each parent layer - #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states), level = "debug")] + #[tracing::instrument(skip_all, level = "debug")] pub fn save_state_from_diff( &self, shortstatehash: u64, statediffnew: Arc>, statediffremoved: Arc>, diff_to_sibling: usize, mut parent_states: ParentStatesVec, - ) -> Result<()> { + ) -> Result { let statediffnew_len = statediffnew.len(); let statediffremoved_len = statediffremoved.len(); let diffsum = checked!(statediffnew_len + statediffremoved_len)?; @@ -226,7 +219,7 @@ impl Service { if parent_states.is_empty() { // There is no parent layer, create a new state - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { parent: None, @@ -279,7 +272,7 @@ impl Service { )?; } else { // Diff small enough, we add diff as layer on top of parent - self.db.save_statediff( + self.save_statediff( shortstatehash, &StateDiff { parent: Some(parent.0), @@ -324,7 +317,7 @@ impl Service { let states_parents = if let Some(p) = previous_shortstatehash { self.load_shortstatehash_info(p).await.unwrap_or_default() } else { - ShortStateInfoResult::new() + ShortStateInfoVec::new() }; let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { @@ -356,4 +349,63 @@ impl Service { Ok((new_shortstatehash, statediffnew, statediffremoved)) } + + async fn get_statediff(&self, shortstatehash: u64) -> Result { + const BUFSIZE: usize = size_of::(); + const STRIDE: usize = size_of::(); + + let value = self + .db + .shortstatehash_statediff + .aqry::(&shortstatehash) + .await + .map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?; + + let parent = utils::u64_from_bytes(&value[0..size_of::()]) + .ok() + .take_if(|parent| *parent != 0); + + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); + + let mut i = STRIDE; + while let Some(v) = value.get(i..expected!(i + 2 * STRIDE)) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i = expected!(i + STRIDE); + continue; + } + if add_mode { + added.insert(v.try_into()?); + } else { + removed.insert(v.try_into()?); + } + i = expected!(i + 2 * STRIDE); + } + + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) + } + + fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } + + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } + + self.db + .shortstatehash_statediff + .insert(&shortstatehash.to_be_bytes(), &value); + } }