From af399fd5179eed9c72bf0426858301af9ffc92d4 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 29 Jan 2025 01:04:02 +0000 Subject: [PATCH] flatten state accessor iterations Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 11 +- src/api/client/context.rs | 28 ++-- src/api/client/membership.rs | 18 +- src/api/client/message.rs | 6 +- src/api/client/room/initial_sync.rs | 9 +- src/api/client/search.rs | 12 +- src/api/client/state.rs | 12 +- src/api/client/sync/v3.rs | 48 +++--- src/api/client/sync/v4.rs | 8 +- src/api/client/sync/v5.rs | 8 +- src/api/server/send_join.rs | 14 +- src/api/server/state.rs | 10 +- src/api/server/state_ids.rs | 9 +- src/core/pdu/strip.rs | 12 +- .../rooms/event_handler/resolve_state.rs | 5 +- .../rooms/event_handler/state_at_incoming.rs | 17 +- src/service/rooms/spaces/mod.rs | 4 +- src/service/rooms/state_accessor/mod.rs | 155 ++++++++++-------- 18 files changed, 205 insertions(+), 181 deletions(-) diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index cdd69c0f..cd892ded 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -9,7 +9,7 @@ use conduwuit::{ debug_error, err, info, trace, utils, utils::string::EMPTY, warn, Error, PduEvent, PduId, RawPduId, Result, }; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{ api::{client::error::ErrorKind, federation::event::get_room_state}, events::room::message::RoomMessageEventContent, @@ -327,11 +327,10 @@ pub(super) async fn get_room_state( .services .rooms .state_accessor - .room_state_full(&room_id) - .await? - .values() - .map(PduEvent::to_state_event) - .collect(); + .room_state_full_pdus(&room_id) + .map_ok(PduEvent::into_state_event) + .try_collect() + .await?; if room_state.is_empty() { return Ok(RoomMessageEventContent::text_plain( diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 388bcf4d..7256683f 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -1,6 +1,6 @@ use axum::extract::State; use conduwuit::{ - at, deref_at, err, ref_at, + at, err, ref_at, utils::{ future::TryExtExt, stream::{BroadbandExt, ReadyExt, TryIgnore, WidebandExt}, @@ -10,10 +10,10 @@ use conduwuit::{ }; use futures::{ future::{join, join3, try_join3, OptionFuture}, - FutureExt, StreamExt, TryFutureExt, + FutureExt, StreamExt, TryFutureExt, TryStreamExt, }; use ruma::{api::client::context::get_context, events::StateEventType, OwnedEventId, UserId}; -use service::rooms::{lazy_loading, lazy_loading::Options}; +use service::rooms::{lazy_loading, lazy_loading::Options, short::ShortStateKey}; use crate::{ client::message::{event_filter, ignored_filter, lazy_loading_witness, visibility_filter}, @@ -132,21 +132,29 @@ pub(crate) async fn get_context_route( .state_accessor .pdu_shortstatehash(state_at) .or_else(|_| services.rooms.state.get_room_shortstatehash(room_id)) - .and_then(|shortstatehash| services.rooms.state_accessor.state_full_ids(shortstatehash)) + .map_ok(|shortstatehash| { + services + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .map(Ok) + }) .map_err(|e| err!(Database("State not found: {e}"))) + .try_flatten_stream() + .try_collect() .boxed(); let (lazy_loading_witnessed, state_ids) = join(lazy_loading_witnessed, state_ids).await; - let state_ids = state_ids?; + let state_ids: Vec<(ShortStateKey, OwnedEventId)> = state_ids?; + let shortstatekeys = state_ids.iter().map(at!(0)).stream(); + let shorteventids = state_ids.iter().map(ref_at!(1)).stream(); let lazy_loading_witnessed = lazy_loading_witnessed.unwrap_or_default(); - let shortstatekeys = state_ids.iter().stream().map(deref_at!(0)); - let state: Vec<_> = services .rooms .short .multi_get_statekey_from_short(shortstatekeys) - .zip(state_ids.iter().stream().map(at!(1))) + .zip(shorteventids) .ready_filter_map(|item| Some((item.0.ok()?, item.1))) .ready_filter_map(|((event_type, state_key), event_id)| { if filter.lazy_load_options.is_enabled() @@ -162,9 +170,9 @@ pub(crate) async fn get_context_route( Some(event_id) }) .broad_filter_map(|event_id: &OwnedEventId| { - services.rooms.timeline.get_pdu(event_id).ok() + services.rooms.timeline.get_pdu(event_id.as_ref()).ok() }) - .map(|pdu| pdu.to_state_event()) + .map(PduEvent::into_state_event) .collect() .await; diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 2e23dab9..fccb9b53 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -8,14 +8,14 @@ use std::{ use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ - debug, debug_info, debug_warn, err, info, + at, debug, debug_info, debug_warn, err, info, pdu::{gen_event_id_canonical_json, PduBuilder}, result::FlatOk, trace, utils::{self, shuffle, IterStream, ReadyExt}, warn, Err, PduEvent, Result, }; -use futures::{join, FutureExt, StreamExt}; +use futures::{join, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::{ client::{ @@ -765,11 +765,12 @@ pub(crate) async fn get_member_events_route( .rooms .state_accessor .room_state_full(&body.room_id) - .await? - .iter() - .filter(|(key, _)| key.0 == StateEventType::RoomMember) - .map(|(_, pdu)| pdu.to_member_event()) - .collect(), + .ready_filter_map(Result::ok) + .ready_filter(|((ty, _), _)| *ty == StateEventType::RoomMember) + .map(at!(1)) + .map(PduEvent::into_member_event) + .collect() + .await, }) } @@ -1707,9 +1708,6 @@ pub async fn leave_room( room_id: &RoomId, reason: Option, ) -> Result<()> { - //use conduwuit::utils::stream::OptionStream; - use futures::TryFutureExt; - // Ask a remote server if we don't have this room and are not knocking on it if !services .rooms diff --git a/src/api/client/message.rs b/src/api/client/message.rs index a508b5da..321d8013 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -6,9 +6,9 @@ use conduwuit::{ stream::{BroadbandExt, TryIgnore, WidebandExt}, IterStream, ReadyExt, }, - Event, PduCount, Result, + Event, PduCount, PduEvent, Result, }; -use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt}; +use futures::{future::OptionFuture, pin_mut, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::{ client::{filter::RoomEventFilter, message::get_message_events}, @@ -220,8 +220,8 @@ async fn get_member_event( .rooms .state_accessor .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str()) + .map_ok(PduEvent::into_state_event) .await - .map(|member_event| member_event.to_state_event()) .ok() } diff --git a/src/api/client/room/initial_sync.rs b/src/api/client/room/initial_sync.rs index 301b6e8d..233d180f 100644 --- a/src/api/client/room/initial_sync.rs +++ b/src/api/client/room/initial_sync.rs @@ -2,7 +2,7 @@ use axum::extract::State; use conduwuit::{ at, utils::{stream::TryTools, BoolExt}, - Err, Result, + Err, PduEvent, Result, }; use futures::TryStreamExt; use ruma::api::client::room::initial_sync::v3::{PaginationChunk, Request, Response}; @@ -39,10 +39,9 @@ pub(crate) async fn room_initial_sync_route( .rooms .state_accessor .room_state_full_pdus(room_id) - .await? - .into_iter() - .map(|pdu| pdu.to_state_event()) - .collect(); + .map_ok(PduEvent::into_state_event) + .try_collect() + .await?; let messages = PaginationChunk { start: events.last().map(at!(0)).as_ref().map(ToString::to_string), diff --git a/src/api/client/search.rs b/src/api/client/search.rs index e60bd26d..898dfc7f 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -7,7 +7,7 @@ use conduwuit::{ utils::{stream::ReadyExt, IterStream}, Err, PduEvent, Result, }; -use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt}; +use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ api::client::search::search_events::{ self, @@ -181,15 +181,15 @@ async fn category_room_events( } async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result { - let state_map = services + let state = services .rooms .state_accessor - .room_state_full(room_id) + .room_state_full_pdus(room_id) + .map_ok(PduEvent::into_state_event) + .try_collect() .await?; - let state_events = state_map.values().map(PduEvent::to_state_event).collect(); - - Ok(state_events) + Ok(state) } async fn check_room_visible( diff --git a/src/api/client/state.rs b/src/api/client/state.rs index d00ee5e5..8555f88b 100644 --- a/src/api/client/state.rs +++ b/src/api/client/state.rs @@ -1,5 +1,6 @@ use axum::extract::State; use conduwuit::{err, pdu::PduBuilder, utils::BoolExt, Err, PduEvent, Result}; +use futures::TryStreamExt; use ruma::{ api::client::state::{get_state_events, get_state_events_for_key, send_state_event}, events::{ @@ -82,11 +83,10 @@ pub(crate) async fn get_state_events_route( room_state: services .rooms .state_accessor - .room_state_full(&body.room_id) - .await? - .values() - .map(PduEvent::to_state_event) - .collect(), + .room_state_full_pdus(&body.room_id) + .map_ok(PduEvent::into_state_event) + .try_collect() + .await?, }) } @@ -133,7 +133,7 @@ pub(crate) async fn get_state_events_for_key_route( Ok(get_state_events_for_key::v3::Response { content: event_format.or(|| event.get_content_as_value()), - event: event_format.then(|| event.to_state_event_value()), + event: event_format.then(|| event.into_state_event_value()), }) } diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 7cca9616..cd4dfc90 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -28,7 +28,7 @@ use conduwuit_service::{ }; use futures::{ future::{join, join3, join4, join5, try_join, try_join4, OptionFuture}, - FutureExt, StreamExt, TryFutureExt, + FutureExt, StreamExt, TryFutureExt, TryStreamExt, }; use ruma::{ api::client::{ @@ -503,16 +503,20 @@ async fn handle_left_room( let mut left_state_events = Vec::new(); - let since_shortstatehash = services - .rooms - .user - .get_token_shortstatehash(room_id, since) - .await; + let since_shortstatehash = services.rooms.user.get_token_shortstatehash(room_id, since); - let since_state_ids = match since_shortstatehash { - | Ok(s) => services.rooms.state_accessor.state_full_ids(s).await?, - | Err(_) => HashMap::new(), - }; + let since_state_ids: HashMap<_, OwnedEventId> = since_shortstatehash + .map_ok(|since_shortstatehash| { + services + .rooms + .state_accessor + .state_full_ids(since_shortstatehash) + .map(Ok) + }) + .try_flatten_stream() + .try_collect() + .await + .unwrap_or_default(); let Ok(left_event_id): Result = services .rooms @@ -534,11 +538,12 @@ async fn handle_left_room( return Ok(None); }; - let mut left_state_ids = services + let mut left_state_ids: HashMap<_, _> = services .rooms .state_accessor .state_full_ids(left_shortstatehash) - .await?; + .collect() + .await; let leave_shortstatekey = services .rooms @@ -960,19 +965,18 @@ async fn calculate_state_initial( current_shortstatehash: ShortStateHash, witness: Option<&Witness>, ) -> Result { - let state_events = services + let (shortstatekeys, event_ids): (Vec<_>, Vec<_>) = services .rooms .state_accessor .state_full_ids(current_shortstatehash) - .await?; - - let shortstatekeys = state_events.keys().copied().stream(); + .unzip() + .await; let state_events = services .rooms .short - .multi_get_statekey_from_short(shortstatekeys) - .zip(state_events.values().cloned().stream()) + .multi_get_statekey_from_short(shortstatekeys.into_iter().stream()) + .zip(event_ids.into_iter().stream()) .ready_filter_map(|item| Some((item.0.ok()?, item.1))) .ready_filter_map(|((event_type, state_key), event_id)| { let lazy_load_enabled = filter.room.state.lazy_load_options.is_enabled() @@ -1036,17 +1040,19 @@ async fn calculate_state_incremental( let current_state_ids = services .rooms .state_accessor - .state_full_ids(current_shortstatehash); + .state_full_ids(current_shortstatehash) + .collect(); let since_state_ids = services .rooms .state_accessor - .state_full_ids(since_shortstatehash); + .state_full_ids(since_shortstatehash) + .collect(); let (current_state_ids, since_state_ids): ( HashMap<_, OwnedEventId>, HashMap<_, OwnedEventId>, - ) = try_join(current_state_ids, since_state_ids).await?; + ) = join(current_state_ids, since_state_ids).await; current_state_ids .iter() diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index a82e9309..b7967498 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -241,13 +241,15 @@ pub(crate) async fn sync_events_v4_route( .rooms .state_accessor .state_full_ids(current_shortstatehash) - .await?; + .collect() + .await; - let since_state_ids = services + let since_state_ids: HashMap<_, _> = services .rooms .state_accessor .state_full_ids(since_shortstatehash) - .await?; + .collect() + .await; for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 1c4f3504..66647f0e 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -748,13 +748,15 @@ async fn collect_e2ee<'a>( .rooms .state_accessor .state_full_ids(current_shortstatehash) - .await?; + .collect() + .await; - let since_state_ids = services + let since_state_ids: HashMap<_, _> = services .rooms .state_accessor .state_full_ids(since_shortstatehash) - .await?; + .collect() + .await; for (key, id) in current_state_ids { if since_state_ids.get(&key) != Some(&id) { diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index e62089b4..2b8a0eef 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -1,10 +1,10 @@ #![allow(deprecated)] -use std::{borrow::Borrow, collections::HashMap}; +use std::borrow::Borrow; use axum::extract::State; use conduwuit::{ - err, + at, err, pdu::gen_event_id_canonical_json, utils::stream::{IterStream, TryBroadbandExt}, warn, Err, Result, @@ -211,14 +211,16 @@ async fn create_join_event( drop(mutex_lock); - let state_ids: HashMap<_, OwnedEventId> = services + let state_ids: Vec = services .rooms .state_accessor .state_full_ids(shortstatehash) - .await?; + .map(at!(1)) + .collect() + .await; let state = state_ids - .values() + .iter() .try_stream() .broad_and_then(|event_id| services.rooms.timeline.get_pdu_json(event_id)) .broad_and_then(|pdu| { @@ -231,7 +233,7 @@ async fn create_join_event( .boxed() .await?; - let starting_events = state_ids.values().map(Borrow::borrow); + let starting_events = state_ids.iter().map(Borrow::borrow); let auth_chain = services .rooms .auth_chain diff --git a/src/api/server/state.rs b/src/api/server/state.rs index 42f7e538..eab1f138 100644 --- a/src/api/server/state.rs +++ b/src/api/server/state.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use conduwuit::{err, result::LogErr, utils::IterStream, Result}; +use conduwuit::{at, err, utils::IterStream, Result}; use futures::{FutureExt, StreamExt, TryStreamExt}; use ruma::{api::federation::event::get_room_state, OwnedEventId}; @@ -35,11 +35,9 @@ pub(crate) async fn get_room_state_route( .rooms .state_accessor .state_full_ids(shortstatehash) - .await - .log_err() - .map_err(|_| err!(Request(NotFound("PDU state IDs not found."))))? - .into_values() - .collect(); + .map(at!(1)) + .collect() + .await; let pdus = state_ids .iter() diff --git a/src/api/server/state_ids.rs b/src/api/server/state_ids.rs index 186ef399..4973dd3a 100644 --- a/src/api/server/state_ids.rs +++ b/src/api/server/state_ids.rs @@ -1,7 +1,7 @@ use std::{borrow::Borrow, iter::once}; use axum::extract::State; -use conduwuit::{err, Result}; +use conduwuit::{at, err, Result}; use futures::StreamExt; use ruma::{api::federation::event::get_room_state_ids, OwnedEventId}; @@ -36,10 +36,9 @@ pub(crate) async fn get_room_state_ids_route( .rooms .state_accessor .state_full_ids(shortstatehash) - .await - .map_err(|_| err!(Request(NotFound("State ids not found"))))? - .into_values() - .collect(); + .map(at!(1)) + .collect() + .await; let auth_chain_ids = services .rooms diff --git a/src/core/pdu/strip.rs b/src/core/pdu/strip.rs index 8e1045db..7d2fb1d6 100644 --- a/src/core/pdu/strip.rs +++ b/src/core/pdu/strip.rs @@ -116,7 +116,7 @@ pub fn to_message_like_event(&self) -> Raw { #[must_use] #[implement(super::Pdu)] -pub fn to_state_event_value(&self) -> JsonValue { +pub fn into_state_event_value(self) -> JsonValue { let mut json = json!({ "content": self.content, "type": self.kind, @@ -127,7 +127,7 @@ pub fn to_state_event_value(&self) -> JsonValue { "state_key": self.state_key, }); - if let Some(unsigned) = &self.unsigned { + if let Some(unsigned) = self.unsigned { json["unsigned"] = json!(unsigned); } @@ -136,8 +136,8 @@ pub fn to_state_event_value(&self) -> JsonValue { #[must_use] #[implement(super::Pdu)] -pub fn to_state_event(&self) -> Raw { - serde_json::from_value(self.to_state_event_value()).expect("Raw::from_value always works") +pub fn into_state_event(self) -> Raw { + serde_json::from_value(self.into_state_event_value()).expect("Raw::from_value always works") } #[must_use] @@ -188,7 +188,7 @@ pub fn to_stripped_spacechild_state_event(&self) -> Raw Raw> { +pub fn into_member_event(self) -> Raw> { let mut json = json!({ "content": self.content, "type": self.kind, @@ -200,7 +200,7 @@ pub fn to_member_event(&self) -> Raw> { "state_key": self.state_key, }); - if let Some(unsigned) = &self.unsigned { + if let Some(unsigned) = self.unsigned { json["unsigned"] = json!(unsigned); } diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 0526d31c..1fd91ac6 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -33,11 +33,12 @@ pub async fn resolve_state( .await .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; - let current_state_ids = self + let current_state_ids: HashMap<_, _> = self .services .state_accessor .state_full_ids(current_sstatehash) - .await?; + .collect() + .await; let fork_states = [current_state_ids, incoming_state]; let auth_chain_sets: Vec> = fork_states diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 9e7f8d2a..7ef047ab 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -31,15 +31,12 @@ pub(super) async fn state_at_incoming_degree_one( return Ok(None); }; - let Ok(mut state) = self + let mut state: HashMap<_, _> = self .services .state_accessor .state_full_ids(prev_event_sstatehash) - .await - .log_err() - else { - return Ok(None); - }; + .collect() + .await; debug!("Using cached state"); let prev_pdu = self @@ -103,14 +100,12 @@ pub(super) async fn state_at_incoming_resolved( let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); for (sstatehash, prev_event) in extremity_sstatehashes { - let Ok(mut leaf_state) = self + let mut leaf_state: HashMap<_, _> = self .services .state_accessor .state_full_ids(sstatehash) - .await - else { - continue; - }; + .collect() + .await; if let Some(state_key) = &prev_event.state_key { let shortstatekey = self diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index d60c4c9e..d12a01ab 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -624,8 +624,8 @@ impl Service { .services .state_accessor .state_full_ids(current_shortstatehash) - .await - .map_err(|e| err!(Database("State in space not found: {e}")))?; + .collect() + .await; let mut children_pdus = Vec::with_capacity(state.len()); for (key, id) in state { diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 3d87534b..0f5520bb 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,6 +1,5 @@ use std::{ borrow::Borrow, - collections::HashMap, fmt::Write, sync::{Arc, Mutex as StdMutex, Mutex}, }; @@ -17,7 +16,7 @@ use conduwuit::{ Err, Error, PduEvent, Result, }; use database::{Deserialized, Map}; -use futures::{FutureExt, StreamExt, TryFutureExt}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ events::{ @@ -143,83 +142,74 @@ impl crate::Service for Service { } impl Service { - pub async fn state_full( + pub fn state_full( &self, shortstatehash: ShortStateHash, - ) -> Result> { - let state = self - .state_full_pdus(shortstatehash) - .await? - .into_iter() - .filter_map(|pdu| Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu))) - .collect(); - - Ok(state) + ) -> impl Stream + Send + '_ { + self.state_full_pdus(shortstatehash) + .ready_filter_map(|pdu| { + Some(((pdu.kind.to_string().into(), pdu.state_key.clone()?), pdu)) + }) } - pub async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result> { - let short_ids = self.state_full_shortids(shortstatehash).await?; + pub fn state_full_pdus( + &self, + shortstatehash: ShortStateHash, + ) -> impl Stream + Send + '_ { + let short_ids = self + .state_full_shortids(shortstatehash) + .map(|result| result.expect("missing shortstatehash")) + .map(Vec::into_iter) + .map(|iter| iter.map(at!(1))) + .map(IterStream::stream) + .flatten_stream() + .boxed(); - let full_pdus = self - .services + self.services .short - .multi_get_eventid_from_short(short_ids.into_iter().map(at!(1)).stream()) + .multi_get_eventid_from_short(short_ids) .ready_filter_map(Result::ok) - .broad_filter_map(|event_id: OwnedEventId| async move { + .broad_filter_map(move |event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() }) - .collect() - .await; - - Ok(full_pdus) } /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids( - &self, + pub fn state_full_ids<'a, Id>( + &'a self, shortstatehash: ShortStateHash, - ) -> Result> + ) -> impl Stream + Send + 'a where - Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned, + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a, ::Owned: Borrow, { - let short_ids = self.state_full_shortids(shortstatehash).await?; - - let full_ids = self - .services - .short - .multi_get_eventid_from_short(short_ids.iter().map(at!(1)).stream()) - .zip(short_ids.iter().stream().map(at!(0))) - .ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?))) - .collect() - .boxed() - .await; - - Ok(full_ids) - } - - #[inline] - pub async fn state_full_shortids( - &self, - shortstatehash: ShortStateHash, - ) -> Result> { let shortids = self - .services - .state_compressor - .load_shortstatehash_info(shortstatehash) - .await - .map_err(|e| err!(Database("Missing state IDs: {e}")))? - .pop() - .expect("there is always one layer") - .full_state - .iter() - .copied() - .map(parse_compressed_state_event) - .collect(); + .state_full_shortids(shortstatehash) + .map(|result| result.expect("missing shortstatehash")) + .map(|vec| vec.into_iter().unzip()) + .boxed() + .shared(); - Ok(shortids) + let shortstatekeys = shortids + .clone() + .map(at!(0)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + let shorteventids = shortids + .map(at!(1)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + self.services + .short + .multi_get_eventid_from_short(shorteventids) + .zip(shortstatekeys) + .ready_filter_map(|(event_id, shortstatekey)| Some((shortstatekey, event_id.ok()?))) } /// Returns a single EventId from `room_id` with key (`event_type`, @@ -264,6 +254,28 @@ impl Service { .await } + #[inline] + pub async fn state_full_shortids( + &self, + shortstatehash: ShortStateHash, + ) -> Result> { + let shortids = self + .services + .state_compressor + .load_shortstatehash_info(shortstatehash) + .await + .map_err(|e| err!(Database("Missing state IDs: {e}")))? + .pop() + .expect("there is always one layer") + .full_state + .iter() + .copied() + .map(parse_compressed_state_event) + .collect(); + + Ok(shortids) + } + /// Returns a single PDU from `room_id` with key (`event_type`, /// `state_key`). pub async fn state_get( @@ -479,27 +491,30 @@ impl Service { /// Returns the full room state. #[tracing::instrument(skip(self), level = "debug")] - pub async fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result> { + pub fn room_state_full<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream> + Send + 'a { self.services .state .get_room_shortstatehash(room_id) - .and_then(|shortstatehash| self.state_full(shortstatehash)) - .map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}"))) - .await + .map_ok(|shortstatehash| self.state_full(shortstatehash).map(Ok)) + .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .try_flatten_stream() } /// Returns the full room state pdus #[tracing::instrument(skip(self), level = "debug")] - pub async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result> { + pub fn room_state_full_pdus<'a>( + &'a self, + room_id: &'a RoomId, + ) -> impl Stream> + Send + 'a { self.services .state .get_room_shortstatehash(room_id) - .and_then(|shortstatehash| self.state_full_pdus(shortstatehash)) - .map_err(|e| err!(Database("Missing state pdus for {room_id:?}: {e:?}"))) - .await + .map_ok(|shortstatehash| self.state_full_pdus(shortstatehash).map(Ok)) + .map_err(move |e| err!(Database("Missing state for {room_id:?}: {e:?}"))) + .try_flatten_stream() } /// Returns a single EventId from `room_id` with key (`event_type`,