From 50acfe783289e6b9b8deb20b3c34f32653f61f11 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 29 Jan 2025 08:39:44 +0000 Subject: [PATCH] flatten auth chain iterations Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 7 +- src/api/server/event_auth.rs | 4 +- src/api/server/send_join.rs | 2 - src/api/server/state.rs | 2 - src/api/server/state_ids.rs | 8 +- src/service/rooms/auth_chain/mod.rs | 154 +++++++++--------- .../rooms/event_handler/resolve_state.rs | 15 +- .../rooms/event_handler/state_at_incoming.rs | 9 +- 8 files changed, 90 insertions(+), 111 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index cd892ded..4e0ce2e3 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -6,8 +6,9 @@ use std::{ }; use conduwuit::{ - debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, PduId, - RawPduId, Result, + debug_error, err, info, trace, utils, + utils::{stream::ReadyExt, string::EMPTY}, + warn, Error, PduEvent, PduId, RawPduId, Result, }; use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ @@ -54,7 +55,7 @@ pub(super) async fn get_auth_chain( .rooms .auth_chain .event_ids_iter(room_id, once(event_id.as_ref())) - .await? + .ready_filter_map(Result::ok) .count() .await; diff --git a/src/api/server/event_auth.rs b/src/api/server/event_auth.rs index 93e867a0..49dcd718 100644 --- a/src/api/server/event_auth.rs +++ b/src/api/server/event_auth.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use conduwuit::{Error, Result}; +use conduwuit::{utils::stream::ReadyExt, Error, Result}; use futures::StreamExt; use ruma::{ api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, @@ -48,7 +48,7 @@ pub(crate) async fn get_event_authorization_route( .rooms .auth_chain .event_ids_iter(room_id, once(body.event_id.borrow())) - .await? + .ready_filter_map(Result::ok) .filter_map(|id| async move { services.rooms.timeline.get_pdu_json(&id).await.ok() }) .then(|pdu| services.sending.convert_to_outgoing_federation_event(pdu)) .collect() diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index 2b8a0eef..e81d7672 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -238,8 +238,6 @@ async fn create_join_event( .rooms .auth_chain .event_ids_iter(room_id, starting_events) - .await? - .map(Ok) .broad_and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) diff --git a/src/api/server/state.rs b/src/api/server/state.rs index eab1f138..b16e61a0 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -56,8 +56,6 @@ pub(crate) async fn get_room_state_route( .rooms .auth_chain .event_ids_iter(&body.room_id, once(body.event_id.borrow())) - .await? - .map(Ok) .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) .and_then(|pdu| { services diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 4973dd3a..7d0440bf 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -2,7 +2,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; use conduwuit::{at, err, Result}; -use futures::StreamExt; +use futures::{StreamExt, TryStreamExt}; use ruma::{api::federation::event::get_room_state_ids, OwnedEventId}; use super::AccessCheck; @@ -44,10 +44,8 @@ pub(crate) async fn get_room_state_ids_route( .rooms .auth_chain .event_ids_iter(&body.room_id, once(body.event_id.borrow())) - .await? - .map(|id| (*id).to_owned()) - .collect() - .await; + .try_collect() + .await?; Ok(get_room_state_ids::v1::Response { auth_chain_ids, pdu_ids }) } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index df2663b2..0ff96846 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -4,6 +4,7 @@ use std::{ collections::{BTreeSet, HashSet, VecDeque}, fmt::Debug, sync::Arc, + time::Instant, }; use conduwuit::{ @@ -14,7 +15,7 @@ use conduwuit::{ }, validated, warn, Err, Result, }; -use futures::{Stream, StreamExt, TryFutureExt, TryStreamExt}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{EventId, OwnedEventId, RoomId}; use self::data::Data; @@ -30,6 +31,8 @@ struct Services { timeline: Dep, } +type Bucket<'a> = BTreeSet<(u64, &'a EventId)>; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -45,42 +48,22 @@ impl crate::Service for Service { } #[implement(Service)] -pub async fn event_ids_iter<'a, I>( +pub fn event_ids_iter<'a, I>( &'a self, - room_id: &RoomId, + room_id: &'a RoomId, starting_events: I, -) -> Result + Send + '_> +) -> impl Stream> + Send + 'a where I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, { - let stream = self - .get_event_ids(room_id, starting_events) - .await? - .into_iter() - .stream(); - - Ok(stream) -} - -#[implement(Service)] -pub async fn get_event_ids<'a, I>( - &'a self, - room_id: &RoomId, - starting_events: I, -) -> Result> -where - I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, -{ - let chain = self.get_auth_chain(room_id, starting_events).await?; - let event_ids = self - .services - .short - .multi_get_eventid_from_short(chain.into_iter().stream()) - .ready_filter_map(Result::ok) - .collect() - .await; - - Ok(event_ids) + self.get_auth_chain(room_id, starting_events) + .map_ok(|chain| { + self.services + .short + .multi_get_eventid_from_short(chain.into_iter().stream()) + .ready_filter(Result::is_ok) + }) + .try_flatten_stream() } #[implement(Service)] @@ -94,9 +77,9 @@ where I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? - const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); + const BUCKET: Bucket<'_> = BTreeSet::new(); - let started = std::time::Instant::now(); + let started = Instant::now(); let mut starting_ids = self .services .short @@ -120,53 +103,7 @@ where let full_auth_chain: Vec = buckets .into_iter() .try_stream() - .broad_and_then(|chunk| async move { - let chunk_key: Vec = chunk.iter().map(at!(0)).collect(); - - if chunk_key.is_empty() { - return Ok(Vec::new()); - } - - if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { - return Ok(cached.to_vec()); - } - - let chunk_cache: Vec<_> = chunk - .into_iter() - .try_stream() - .broad_and_then(|(shortid, event_id)| async move { - if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await { - return Ok(cached.to_vec()); - } - - let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; - self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice()); - debug!( - ?event_id, - elapsed = ?started.elapsed(), - "Cache missed event" - ); - - Ok(auth_chain) - }) - .try_collect() - .map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) - .map_ok(|mut chunk_cache: Vec<_>| { - chunk_cache.sort_unstable(); - chunk_cache.dedup(); - chunk_cache - }) - .await?; - - self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice()); - debug!( - chunk_cache_length = ?chunk_cache.len(), - elapsed = ?started.elapsed(), - "Cache missed chunk", - ); - - Ok(chunk_cache) - }) + .broad_and_then(|chunk| self.get_auth_chain_outer(room_id, started, chunk)) .try_collect() .map_ok(|auth_chain: Vec<_>| auth_chain.into_iter().flatten().collect()) .map_ok(|mut full_auth_chain: Vec<_>| { @@ -174,6 +111,7 @@ where full_auth_chain.dedup(); full_auth_chain }) + .boxed() .await?; debug!( @@ -185,6 +123,60 @@ where Ok(full_auth_chain) } +#[implement(Service)] +async fn get_auth_chain_outer( + &self, + room_id: &RoomId, + started: Instant, + chunk: Bucket<'_>, +) -> Result> { + let chunk_key: Vec = chunk.iter().map(at!(0)).collect(); + + if chunk_key.is_empty() { + return Ok(Vec::new()); + } + + if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { + return Ok(cached.to_vec()); + } + + let chunk_cache: Vec<_> = chunk + .into_iter() + .try_stream() + .broad_and_then(|(shortid, event_id)| async move { + if let Ok(cached) = self.get_cached_eventid_authchain(&[shortid]).await { + return Ok(cached.to_vec()); + } + + let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?; + self.cache_auth_chain_vec(vec![shortid], auth_chain.as_slice()); + debug!( + ?event_id, + elapsed = ?started.elapsed(), + "Cache missed event" + ); + + Ok(auth_chain) + }) + .try_collect() + .map_ok(|chunk_cache: Vec<_>| chunk_cache.into_iter().flatten().collect()) + .map_ok(|mut chunk_cache: Vec<_>| { + chunk_cache.sort_unstable(); + chunk_cache.dedup(); + chunk_cache + }) + .await?; + + self.cache_auth_chain_vec(chunk_key, chunk_cache.as_slice()); + debug!( + chunk_cache_length = ?chunk_cache.len(), + elapsed = ?started.elapsed(), + "Cache missed chunk", + ); + + Ok(chunk_cache) +} + #[implement(Service)] #[tracing::instrument(name = "inner", level = "trace", skip(self, room_id))] async fn get_auth_chain_inner( diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 1fd91ac6..03f7e822 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -44,18 +44,11 @@ pub async fn resolve_state( let auth_chain_sets: Vec> = fork_states .iter() .try_stream() - .wide_and_then(|state| async move { - let starting_events = state.values().map(Borrow::borrow); - - let auth_chain = self - .services + .wide_and_then(|state| { + self.services .auth_chain - .get_event_ids(room_id, starting_events) - .await? - .into_iter() - .collect(); - - Ok(auth_chain) + .event_ids_iter(room_id, state.values().map(Borrow::borrow)) + .try_collect() }) .try_collect() .await?; diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 7ef047ab..8730232a 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -10,7 +10,7 @@ use conduwuit::{ utils::stream::{BroadbandExt, IterStream}, PduEvent, Result, }; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId}; // TODO: if we know the prev_events of the incoming event we can avoid the @@ -140,10 +140,9 @@ pub(super) async fn state_at_incoming_resolved( let auth_chain: HashSet = self .services .auth_chain - .get_event_ids(room_id, starting_events.into_iter()) - .await? - .into_iter() - .collect(); + .event_ids_iter(room_id, starting_events.into_iter()) + .try_collect() + .await?; auth_chain_sets.push(auth_chain); fork_states.push(state);