From 36677bb9828038294d06f2292eef755139216c40 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 1 Oct 2024 23:19:47 +0000 Subject: [PATCH] optimize auth_chain short_id to event_id translation step Signed-off-by: Jason Volk --- src/service/rooms/auth_chain/mod.rs | 30 ++++++++++++++++++-------- src/service/rooms/event_handler/mod.rs | 16 +++++++------- src/service/rooms/short/mod.rs | 17 +++++++++++++++ 3 files changed, 46 insertions(+), 17 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index eae13b74..f3861ca3 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; -use futures::{FutureExt, Stream, StreamExt}; +use futures::Stream; use ruma::{EventId, RoomId}; use self::data::Data; @@ -40,15 +40,27 @@ impl Service { pub async fn event_ids_iter( &self, room_id: &RoomId, starting_events: &[&EventId], ) -> Result> + Send + '_> { - let chain = self.get_auth_chain(room_id, starting_events).await?; - let iter = chain.into_iter().stream().filter_map(|sid| { - self.services - .short - .get_eventid_from_short(sid) - .map(Result::ok) - }); + let stream = self + .get_event_ids(room_id, starting_events) + .await? + .into_iter() + .stream(); - Ok(iter) + Ok(stream) + } + + pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result>> { + let chain = self.get_auth_chain(room_id, starting_events).await?; + let event_ids = self + .services + .short + .multi_get_eventid_from_short(&chain) + .await + .into_iter() + .filter_map(Result::ok) + .collect(); + + Ok(event_ids) } #[tracing::instrument(skip_all, name = "auth_chain")] diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 57b87706..4708a86c 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -797,13 +797,13 @@ impl Service { for state in &fork_states { let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); - let auth_chain = self + let auth_chain: HashSet> = self .services .auth_chain - .event_ids_iter(room_id, &starting_events) + .get_event_ids(room_id, &starting_events) .await? - .collect::>>() - .await; + .into_iter() + .collect(); auth_chain_sets.push(auth_chain); } @@ -983,13 +983,13 @@ impl Service { starting_events.push(id.borrow()); } - let auth_chain = self + let auth_chain: HashSet> = self .services .auth_chain - .event_ids_iter(room_id, &starting_events) + .get_event_ids(room_id, &starting_events) .await? - .collect() - .await; + .into_iter() + .collect(); auth_chain_sets.push(auth_chain); fork_states.push(state); diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 825ee109..20082da2 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -141,6 +141,23 @@ pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result Vec>> { + const BUFSIZE: usize = size_of::(); + + let keys: Vec<[u8; BUFSIZE]> = shorteventid + .iter() + .map(|short| short.to_be_bytes()) + .collect(); + + self.db + .shorteventid_eventid + .get_batch_blocking(keys.iter()) + .into_iter() + .map(Deserialized::deserialized) + .collect() +} + #[implement(Service)] pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { const BUFSIZE: usize = size_of::();