From 3789d60b6abf7d758bb75f898ccbaa7f1b4251aa Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Fri, 22 Nov 2024 12:25:46 +0000 Subject: [PATCH] refactor to iterator inputs for auth_chain/short batch functions Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 3 +- src/api/server/event_auth.rs | 4 +- src/api/server/send_join.rs | 6 +- src/api/server/state.rs | 4 +- src/api/server/state_ids.rs | 4 +- src/database/map/get.rs | 8 +- src/service/rooms/auth_chain/mod.rs | 32 +++++--- .../rooms/event_handler/resolve_state.rs | 4 +- .../rooms/event_handler/state_at_incoming.rs | 2 +- src/service/rooms/short/mod.rs | 80 +++++++++---------- 10 files changed, 76 insertions(+), 71 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index f9d4a521..89e47d4e 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -1,6 +1,7 @@ use std::{ collections::HashMap, fmt::Write, + iter::once, sync::Arc, time::{Instant, SystemTime}, }; @@ -43,7 +44,7 @@ pub(super) async fn get_auth_chain(&self, event_id: Box) -> Result = state_ids.values().map(Borrow::borrow).collect(); + let starting_events = state_ids.values().map(Borrow::borrow); let auth_chain = services .rooms .auth_chain - .event_ids_iter(room_id, &starting_events) + .event_ids_iter(room_id, starting_events) .await? .map(Ok) .and_then(|event_id| async move { services.rooms.timeline.get_pdu_json(&event_id).await }) diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 06a44a99..b21fce68 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::{borrow::Borrow, iter::once}; use axum::extract::State; use conduit::{err, result::LogErr, utils::IterStream, Result}; @@ -52,7 +52,7 @@ pub(crate) async fn get_room_state_route( let auth_chain = services .rooms .auth_chain - .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .event_ids_iter(&body.room_id, once(body.event_id.borrow())) .await? .map(Ok) .and_then(|id| async move { services.rooms.timeline.get_pdu_json(&id).await }) diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 52d8e7cc..0c023bf0 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,4 +1,4 @@ -use std::borrow::Borrow; +use std::{borrow::Borrow, iter::once}; use axum::extract::State; use conduit::{err, Result}; @@ -44,7 +44,7 @@ pub(crate) async fn get_room_state_ids_route( let auth_chain_ids = services .rooms .auth_chain - .event_ids_iter(&body.room_id, &[body.event_id.borrow()]) + .event_ids_iter(&body.room_id, once(body.event_id.borrow())) .await? .map(|id| (*id).to_owned()) .collect() diff --git a/src/database/map/get.rs b/src/database/map/get.rs index a3c6c492..3ee2a194 100644 --- a/src/database/map/get.rs +++ b/src/database/map/get.rs @@ -80,8 +80,8 @@ where #[tracing::instrument(skip(self, keys), fields(%self), level = "trace")] pub fn get_batch<'a, I, K>(&self, keys: I) -> impl Stream>> where - I: Iterator + ExactSizeIterator + Send + Debug, - K: AsRef<[u8]> + Send + Sync + Sized + Debug + 'a, + I: Iterator + ExactSizeIterator + Debug + Send, + K: AsRef<[u8]> + Debug + Send + ?Sized + Sync + 'a, { self.get_batch_blocking(keys).stream() } @@ -89,8 +89,8 @@ where #[implement(super::Map)] pub fn get_batch_blocking<'a, I, K>(&self, keys: I) -> impl Iterator>> where - I: Iterator + ExactSizeIterator + Send, - K: AsRef<[u8]> + Sized + 'a, + I: Iterator + ExactSizeIterator + Debug + Send, + K: AsRef<[u8]> + Debug + Send + ?Sized + Sync + 'a, { // Optimization can be `true` if key vector is pre-sorted **by the column // comparator**. diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index cabb6f0c..1d0490c2 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -2,6 +2,7 @@ mod data; use std::{ collections::{BTreeSet, HashSet}, + fmt::Debug, sync::Arc, }; @@ -37,9 +38,12 @@ impl crate::Service for Service { } impl Service { - pub async fn event_ids_iter( - &self, room_id: &RoomId, starting_events: &[&EventId], - ) -> Result> + Send + '_> { + pub async fn event_ids_iter<'a, I>( + &'a self, room_id: &RoomId, starting_events: I, + ) -> Result> + Send + '_> + where + I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + { let stream = self .get_event_ids(room_id, starting_events) .await? @@ -49,12 +53,15 @@ impl Service { Ok(stream) } - pub async fn get_event_ids(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result>> { + pub async fn get_event_ids<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result>> + where + I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + { let chain = self.get_auth_chain(room_id, starting_events).await?; let event_ids = self .services .short - .multi_get_eventid_from_short(&chain) + .multi_get_eventid_from_short(chain.into_iter()) .await .into_iter() .filter_map(Result::ok) @@ -64,7 +71,10 @@ impl Service { } #[tracing::instrument(skip_all, name = "auth_chain")] - pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { + pub async fn get_auth_chain<'a, I>(&'a self, room_id: &RoomId, starting_events: I) -> Result> + where + I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); @@ -72,19 +82,19 @@ impl Service { let mut starting_ids = self .services .short - .multi_get_or_create_shorteventid(starting_events) - .enumerate() + .multi_get_or_create_shorteventid(starting_events.clone()) + .zip(starting_events.clone().stream()) .boxed(); let mut buckets = [BUCKET; NUM_BUCKETS]; - while let Some((i, short)) = starting_ids.next().await { + while let Some((short, starting_event)) = starting_ids.next().await { let bucket: usize = short.try_into()?; let bucket: usize = validated!(bucket % NUM_BUCKETS); - buckets[bucket].insert((short, starting_events[i])); + buckets[bucket].insert((short, starting_event)); } debug!( - starting_events = ?starting_events.len(), + starting_events = ?starting_events.count(), elapsed = ?started.elapsed(), "start", ); diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 0c9525dd..4863e340 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -35,12 +35,12 @@ pub async fn resolve_state( let fork_states = [current_state_ids, incoming_state]; let mut auth_chain_sets = Vec::with_capacity(fork_states.len()); for state in &fork_states { - let starting_events: Vec<&EventId> = state.values().map(Borrow::borrow).collect(); + let starting_events = state.values().map(Borrow::borrow); let auth_chain: HashSet> = self .services .auth_chain - .get_event_ids(room_id, &starting_events) + .get_event_ids(room_id, starting_events) .await? .into_iter() .collect(); diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index a200ab56..05a9d8ca 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -139,7 +139,7 @@ pub(super) async fn state_at_incoming_resolved( let auth_chain: HashSet> = self .services .auth_chain - .get_event_ids(room_id, &starting_events) + .get_event_ids(room_id, starting_events.into_iter()) .await? .into_iter() .collect(); diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 703df796..e4ff2975 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,7 +1,7 @@ -use std::{mem::size_of_val, sync::Arc}; +use std::{fmt::Debug, mem::size_of_val, sync::Arc}; pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; -use conduit::{err, implement, utils, Result}; +use conduit::{err, implement, utils, utils::stream::ReadyExt, Result}; use database::{Deserialized, Map}; use futures::{Stream, StreamExt}; use ruma::{events::StateEventType, EventId, RoomId}; @@ -51,52 +51,46 @@ impl crate::Service for Service { #[implement(Service)] pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { - const BUFSIZE: usize = size_of::(); - if let Ok(shorteventid) = self.get_shorteventid(event_id).await { return shorteventid; } - let shorteventid = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&shorteventid) == BUFSIZE, "buffer requirement changed"); - - self.db - .eventid_shorteventid - .raw_aput::(event_id, shorteventid); - - self.db - .shorteventid_eventid - .aput_raw::(shorteventid, event_id); - - shorteventid + self.create_shorteventid(event_id) } #[implement(Service)] -pub fn multi_get_or_create_shorteventid<'a>( - &'a self, event_ids: &'a [&EventId], -) -> impl Stream + Send + 'a { +pub fn multi_get_or_create_shorteventid<'a, I>(&'a self, event_ids: I) -> impl Stream + Send + '_ +where + I: Iterator + Clone + Debug + ExactSizeIterator + Send + 'a, + ::Item: AsRef<[u8]> + Send + Sync + 'a, +{ self.db .eventid_shorteventid - .get_batch(event_ids.iter()) - .enumerate() - .map(|(i, result)| match result { - Ok(ref short) => utils::u64_from_u8(short), - Err(_) => { - const BUFSIZE: usize = size_of::(); - - let short = self.services.globals.next_count().unwrap(); - debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); - - self.db - .eventid_shorteventid - .raw_aput::(event_ids[i], short); - self.db - .shorteventid_eventid - .aput_raw::(short, event_ids[i]); - - short - }, + .get_batch(event_ids.clone()) + .ready_scan(event_ids, |event_ids, result| { + event_ids.next().map(|event_id| (event_id, result)) }) + .map(|(event_id, result)| match result { + Ok(ref short) => utils::u64_from_u8(short), + Err(_) => self.create_shorteventid(event_id), + }) +} + +#[implement(Service)] +fn create_shorteventid(&self, event_id: &EventId) -> ShortEventId { + const BUFSIZE: usize = size_of::(); + + let short = self.services.globals.next_count().unwrap(); + debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); + + self.db + .eventid_shorteventid + .raw_aput::(event_id, short); + self.db + .shorteventid_eventid + .aput_raw::(short, event_id); + + short } #[implement(Service)] @@ -154,13 +148,13 @@ pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result } #[implement(Service)] -pub async fn multi_get_eventid_from_short(&self, shorteventid: &[ShortEventId]) -> Vec>> { +pub async fn multi_get_eventid_from_short(&self, shorteventid: I) -> Vec>> +where + I: Iterator + Send, +{ const BUFSIZE: usize = size_of::(); - let keys: Vec<[u8; BUFSIZE]> = shorteventid - .iter() - .map(|short| short.to_be_bytes()) - .collect(); + let keys: Vec<[u8; BUFSIZE]> = shorteventid.map(u64::to_be_bytes).collect(); self.db .shorteventid_eventid