outdent auth_chain Service impl

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-20 09:05:49 +00:00
parent 4c0ae8c2f7
commit 610129d162

View file

@ -7,14 +7,14 @@ use std::{
}; };
use conduwuit::{ use conduwuit::{
at, debug, debug_error, trace, at, debug, debug_error, implement, trace,
utils::{ utils::{
stream::{ReadyExt, TryBroadbandExt}, stream::{ReadyExt, TryBroadbandExt},
IterStream, IterStream,
}, },
validated, warn, Err, Result, validated, warn, Err, Result,
}; };
use futures::{Stream, StreamExt, TryStreamExt}; use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{EventId, OwnedEventId, RoomId}; use ruma::{EventId, OwnedEventId, RoomId};
use self::data::Data; use self::data::Data;
@ -44,15 +44,15 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
} }
impl Service { #[implement(Service)]
pub async fn event_ids_iter<'a, I>( pub async fn event_ids_iter<'a, I>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
starting_events: I, starting_events: I,
) -> Result<impl Stream<Item = OwnedEventId> + Send + '_> ) -> Result<impl Stream<Item = OwnedEventId> + Send + '_>
where where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a, I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{ {
let stream = self let stream = self
.get_event_ids(room_id, starting_events) .get_event_ids(room_id, starting_events)
.await? .await?
@ -60,16 +60,17 @@ impl Service {
.stream(); .stream();
Ok(stream) Ok(stream)
} }
pub async fn get_event_ids<'a, I>( #[implement(Service)]
pub async fn get_event_ids<'a, I>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
starting_events: I, starting_events: I,
) -> Result<Vec<OwnedEventId>> ) -> Result<Vec<OwnedEventId>>
where where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a, I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{ {
let chain = self.get_auth_chain(room_id, starting_events).await?; let chain = self.get_auth_chain(room_id, starting_events).await?;
let event_ids = self let event_ids = self
.services .services
@ -80,17 +81,18 @@ impl Service {
.await; .await;
Ok(event_ids) Ok(event_ids)
} }
#[tracing::instrument(name = "auth_chain", level = "debug", skip_all)] #[implement(Service)]
pub async fn get_auth_chain<'a, I>( #[tracing::instrument(name = "auth_chain", level = "debug", skip_all)]
pub async fn get_auth_chain<'a, I>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
starting_events: I, starting_events: I,
) -> Result<Vec<ShortEventId>> ) -> Result<Vec<ShortEventId>>
where where
I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a, I: Iterator<Item = &'a EventId> + Clone + Debug + ExactSizeIterator + Send + 'a,
{ {
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
@ -115,7 +117,7 @@ impl Service {
"start", "start",
); );
let full_auth_chain: Vec<_> = buckets let full_auth_chain: Vec<ShortEventId> = buckets
.into_iter() .into_iter()
.try_stream() .try_stream()
.broad_and_then(|chunk| async move { .broad_and_then(|chunk| async move {
@ -148,11 +150,14 @@ impl Service {
Ok(auth_chain) Ok(auth_chain)
}) })
.try_collect() .try_collect()
.await?; .map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect())
.map_ok(|mut chunk_cache: Vec<_>| {
let mut chunk_cache: Vec<_> = chunk_cache.into_iter().flatten().collect();
chunk_cache.sort_unstable(); chunk_cache.sort_unstable();
chunk_cache.dedup(); chunk_cache.dedup();
chunk_cache
})
.await?;
self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice()); self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice());
debug!( debug!(
chunk_cache_length = ?chunk_cache.len(), chunk_cache_length = ?chunk_cache.len(),
@ -163,11 +168,14 @@ impl Service {
Ok(chunk_cache) Ok(chunk_cache)
}) })
.try_collect() .try_collect()
.await?; .map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect())
.map_ok(|mut full_auth_chain: Vec<_>| {
let mut full_auth_chain: Vec<_> = full_auth_chain.into_iter().flatten().collect();
full_auth_chain.sort_unstable(); full_auth_chain.sort_unstable();
full_auth_chain.dedup(); full_auth_chain.dedup();
full_auth_chain
})
.await?;
debug!( debug!(
chain_length = ?full_auth_chain.len(), chain_length = ?full_auth_chain.len(),
elapsed = ?started.elapsed(), elapsed = ?started.elapsed(),
@ -175,14 +183,15 @@ impl Service {
); );
Ok(full_auth_chain) Ok(full_auth_chain)
} }
#[tracing::instrument(name = "inner", level = "trace", skip(self, room_id))] #[implement(Service)]
async fn get_auth_chain_inner( #[tracing::instrument(name = "inner", level = "trace", skip(self, room_id))]
async fn get_auth_chain_inner(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_id: &EventId, event_id: &EventId,
) -> Result<Vec<ShortEventId>> { ) -> Result<Vec<ShortEventId>> {
let mut todo: VecDeque<_> = [event_id.to_owned()].into(); let mut todo: VecDeque<_> = [event_id.to_owned()].into();
let mut found = HashSet::new(); let mut found = HashSet::new();
@ -211,11 +220,7 @@ impl Service {
.await; .await;
if found.insert(sauthevent) { if found.insert(sauthevent) {
trace!( trace!(?event_id, ?auth_event, "adding auth event to processing queue");
?event_id,
?auth_event,
"adding auth event to processing queue"
);
todo.push_back(auth_event.clone()); todo.push_back(auth_event.clone());
} }
@ -225,32 +230,36 @@ impl Service {
} }
Ok(found.into_iter().collect()) Ok(found.into_iter().collect())
} }
#[inline] #[implement(Service)]
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> { #[inline]
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
self.db.get_cached_eventid_authchain(key).await self.db.get_cached_eventid_authchain(key).await
} }
#[tracing::instrument(skip_all, level = "debug")] #[implement(Service)]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<ShortEventId>) { #[tracing::instrument(skip_all, level = "debug")]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<ShortEventId>) {
let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect(); let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect();
self.db.cache_auth_chain(key, val); self.db.cache_auth_chain(key, val);
} }
#[tracing::instrument(skip_all, level = "debug")] #[implement(Service)]
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) { #[tracing::instrument(skip_all, level = "debug")]
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &[ShortEventId]) {
let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect(); let val: Arc<[ShortEventId]> = auth_chain.iter().copied().collect();
self.db.cache_auth_chain(key, val); self.db.cache_auth_chain(key, val);
} }
pub fn get_cache_usage(&self) -> (usize, usize) { #[implement(Service)]
pub fn get_cache_usage(&self) -> (usize, usize) {
let cache = self.db.auth_chain_cache.lock().expect("locked"); let cache = self.db.auth_chain_cache.lock().expect("locked");
(cache.len(), cache.capacity()) (cache.len(), cache.capacity())
}
pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); }
} }
#[implement(Service)]
pub fn clear_cache(&self) { self.db.auth_chain_cache.lock().expect("locked").clear(); }