From 784ccd6bad25f845532bc3a6e82c5031c50c6444 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 3 Dec 2024 09:42:26 +0000 Subject: [PATCH] return stream from multi_get_eventid_from_short Signed-off-by: Jason Volk --- src/service/rooms/auth_chain/mod.rs | 15 +++++---- src/service/rooms/short/mod.rs | 14 +++----- src/service/rooms/state/mod.rs | 13 +++---- src/service/rooms/state_accessor/data.rs | 43 +++++++++--------------- src/service/rooms/state_accessor/mod.rs | 10 +++--- 5 files changed, 39 insertions(+), 56 deletions(-) diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 5face0b5..e7e5edf4 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -6,7 +6,11 @@ use std::{ sync::Arc, }; -use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result}; +use conduit::{ + debug, debug_error, trace, + utils::{stream::ReadyExt, IterStream}, + validated, warn, Err, Result, +}; use futures::{Stream, StreamExt}; use ruma::{EventId, RoomId}; @@ -61,11 +65,10 @@ impl Service { let event_ids = self .services .short - .multi_get_eventid_from_short(chain.into_iter()) - .await - .into_iter() - .filter_map(Result::ok) - .collect(); + .multi_get_eventid_from_short(chain.iter()) + .ready_filter_map(Result::ok) + .collect() + .await; Ok(event_ids) } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 0f100348..a7c32856 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -139,7 +139,7 @@ pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &s #[implement(Service)] pub async fn get_eventid_from_short(&self, shorteventid: ShortEventId) -> Result where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { const BUFSIZE: usize = size_of::(); @@ -153,22 +153,18 @@ where } #[implement(Service)] -pub async fn multi_get_eventid_from_short(&self, shorteventid: I) -> Vec> +pub fn multi_get_eventid_from_short<'a, Id, I>(&'a self, shorteventid: I) -> impl Stream> + Send + 'a where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + I: Iterator + Send + 'a, + Id: for<'de> Deserialize<'de> + Sized + ToOwned + 'a, ::Owned: Borrow, - I: Iterator + Send, { const BUFSIZE: usize = size_of::(); - let keys: Vec<[u8; BUFSIZE]> = shorteventid.map(u64::to_be_bytes).collect(); - self.db .shorteventid_eventid - .get_batch(keys.iter()) + .aqry_batch::(shorteventid) .map(Deserialized::deserialized) - .collect() - .await } #[implement(Service)] diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 838deacd..d0d21fa8 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -6,7 +6,7 @@ use std::{ }; use conduit::{ - at, err, + err, result::FlatOk, utils::{ calculate_hash, @@ -420,7 +420,7 @@ impl Service { .collect() .await; - let auth_state: Vec<_> = self + let (state_keys, event_ids): (Vec<_>, Vec<_>) = self .services .state_accessor .state_full_shortids(shortstatehash) @@ -432,16 +432,13 @@ impl Service { .remove(&shortstatekey) .map(|(event_type, state_key)| ((event_type, state_key), shorteventid)) }) - .collect(); + .unzip(); let auth_pdus = self .services .short - .multi_get_eventid_from_short(auth_state.iter().map(at!(1))) - .await - .into_iter() - .stream() - .zip(auth_state.into_iter().stream().map(at!(0))) + .multi_get_eventid_from_short(event_ids.iter()) + .zip(state_keys.into_iter().stream()) .ready_filter_map(|(event_id, tsk)| Some((tsk, event_id.ok()?))) .broad_filter_map(|(tsk, event_id): (_, OwnedEventId)| async move { self.services diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 2a670066..bca54069 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -1,12 +1,12 @@ use std::{borrow::Borrow, collections::HashMap, sync::Arc}; use conduit::{ - at, err, - utils::stream::{BroadbandExt, IterStream}, + at, err, ref_at, + utils::stream::{BroadbandExt, IterStream, ReadyExt}, PduEvent, Result, }; use database::{Deserialized, Map}; -use futures::{StreamExt, TryFutureExt}; +use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::{events::StateEventType, EventId, OwnedEventId, RoomId}; use serde::Deserialize; @@ -59,23 +59,13 @@ impl Data { } pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result> { - let short_ids = self - .state_full_shortids(shortstatehash) - .await? - .into_iter() - .map(at!(1)); + let short_ids = self.state_full_shortids(shortstatehash).await?; - let event_ids = self + let full_pdus = self .services .short - .multi_get_eventid_from_short(short_ids) - .await - .into_iter() - .filter_map(Result::ok); - - let full_pdus = event_ids - .into_iter() - .stream() + .multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1))) + .ready_filter_map(Result::ok) .broad_filter_map( |event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() }, ) @@ -92,19 +82,16 @@ impl Data { { let short_ids = self.state_full_shortids(shortstatehash).await?; - let event_ids = self + let full_ids = self .services .short - .multi_get_eventid_from_short(short_ids.iter().map(at!(1))) + .multi_get_eventid_from_short(short_ids.iter().map(ref_at!(1))) + .zip(short_ids.iter().stream().map(at!(0))) + .ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?))) + .collect() + .boxed() .await; - let full_ids = short_ids - .into_iter() - .map(at!(0)) - .zip(event_ids.into_iter()) - .filter_map(|(shortstatekey, event_id)| Some((shortstatekey, event_id.ok()?))) - .collect(); - Ok(full_ids) } @@ -134,7 +121,7 @@ impl Data { &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { let shortstatekey = self @@ -219,7 +206,7 @@ impl Data { &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { self.services diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index e42d3764..ef1b63f5 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -102,7 +102,7 @@ impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result> + pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result> where Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, ::Owned: Borrow, @@ -130,7 +130,7 @@ impl Service { &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { self.db @@ -154,7 +154,7 @@ impl Service { &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result where - T: for<'de> Deserialize<'de> + Send, + T: for<'de> Deserialize<'de>, { self.state_get(shortstatehash, event_type, state_key) .await @@ -337,7 +337,7 @@ impl Service { &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Sized + ToOwned, ::Owned: Borrow, { self.db @@ -359,7 +359,7 @@ impl Service { &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, ) -> Result where - T: for<'de> Deserialize<'de> + Send, + T: for<'de> Deserialize<'de>, { self.room_state_get(room_id, event_type, state_key) .await