diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 1dacb470..435a1c03 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -1,9 +1,9 @@ -use std::{collections::HashSet, mem::size_of, sync::Arc}; +use std::{mem::size_of, sync::Arc}; use crate::{database::KeyValueDatabase, service, utils, Result}; impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { // Check RAM cache if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { return Ok(Some(Arc::clone(result))); @@ -19,12 +19,10 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { chain .chunks_exact(size_of::()) .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) - .collect() + .collect::>() }); if let Some(chain) = chain { - let chain = Arc::new(chain); - // Cache in RAM self.auth_chain_cache .lock() @@ -38,7 +36,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase { Ok(None) } - fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) -> Result<()> { // Only persist single events in db if key.len() == 1 { self.shorteventid_authchain.insert( diff --git a/src/database/mod.rs b/src/database/mod.rs index 32b99a17..f4eef27a 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -181,7 +181,7 @@ pub struct KeyValueDatabase { //pub pusher: pusher::PushData, pub(super) senderkey_pusher: Arc, - pub(super) auth_chain_cache: Mutex, Arc>>>, + pub(super) auth_chain_cache: Mutex, Arc<[u64]>>>, pub(super) our_real_users_cache: RwLock>>>, pub(super) appservice_in_room_cache: RwLock>>, pub(super) lasttimelinecount_cache: Mutex>, diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index c83b4eb0..f77d2d90 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,8 +1,8 @@ -use std::{collections::HashSet, sync::Arc}; +use std::sync::Arc; use crate::Result; pub trait Data: Send + Sync { - fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>>; - fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc>) -> Result<()>; + fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>; + fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc<[u64]>) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index b7f40695..03c49faf 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -15,15 +15,6 @@ pub struct Service { } impl Service { - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { - self.db.get_cached_eventid_authchain(key) - } - - #[tracing::instrument(skip(self))] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { - self.db.cache_auth_chain(key, auth_chain) - } - pub async fn get_auth_chain<'a>( &self, room_id: &RoomId, starting_events: Vec>, ) -> Result> + 'a> { @@ -81,7 +72,7 @@ impl Service { services() .rooms .auth_chain - .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + .cache_auth_chain(vec![sevent_id], &auth_chain)?; debug!( event_id = ?event_id, chain_length = ?auth_chain.len(), @@ -105,7 +96,7 @@ impl Service { services() .rooms .auth_chain - .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + .cache_auth_chain(chunk_key, &chunk_cache)?; full_auth_chain.extend(chunk_cache.iter()); } @@ -154,4 +145,14 @@ impl Service { Ok(found) } + + pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>> { + self.db.get_cached_eventid_authchain(key) + } + + #[tracing::instrument(skip(self))] + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) -> Result<()> { + self.db + .cache_auth_chain(key, auth_chain.iter().copied().collect::>()) + } }