diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 8cb4e586..1b0d0d58 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -429,60 +429,54 @@ impl Service { sender: &UserId, state_key: Option<&str>, content: &serde_json::value::RawValue, - ) -> Result>> { + ) -> Result> { let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else { return Ok(HashMap::new()); }; - let mut sauthevents: HashMap<_, _> = - state_res::auth_types_for_event(kind, sender, state_key, content)? - .iter() - .stream() - .broad_filter_map(|(event_type, state_key)| { - self.services - .short - .get_shortstatekey(event_type, state_key) - .map_ok(move |ssk| (ssk, (event_type, state_key))) - .map(Result::ok) - }) - .map(|(ssk, (event_type, state_key))| { - (ssk, (event_type.to_owned(), state_key.to_owned())) - }) - .collect() - .await; + let auth_types = state_res::auth_types_for_event(kind, sender, state_key, content)?; + + let sauthevents: HashMap<_, _> = auth_types + .iter() + .stream() + .broad_filter_map(|(event_type, state_key)| { + self.services + .short + .get_shortstatekey(event_type, state_key) + .map_ok(move |ssk| (ssk, (event_type, state_key))) + .map(Result::ok) + }) + .collect() + .await; let (state_keys, event_ids): (Vec<_>, Vec<_>) = self .services .state_accessor .state_full_shortids(shortstatehash) - .await - .map_err(|e| err!(Database(error!(?room_id, ?shortstatehash, "{e:?}"))))? - .into_iter() - .filter_map(|(shortstatekey, shorteventid)| { + .ready_filter_map(Result::ok) + .ready_filter_map(|(shortstatekey, shorteventid)| { sauthevents - .remove(&shortstatekey) - .map(|(event_type, state_key)| ((event_type, state_key), shorteventid)) + .get(&shortstatekey) + .map(|(ty, sk)| ((ty, sk), shorteventid)) }) - .unzip(); + .unzip() + .await; - let auth_pdus = self - .services + self.services .short .multi_get_eventid_from_short(event_ids.into_iter().stream()) .zip(state_keys.into_iter().stream()) - .ready_filter_map(|(event_id, tsk)| Some((tsk, event_id.ok()?))) - .broad_filter_map(|(tsk, event_id): (_, OwnedEventId)| async move { + .ready_filter_map(|(event_id, (ty, sk))| Some(((ty, sk), event_id.ok()?))) + .broad_filter_map(|((ty, sk), event_id): (_, OwnedEventId)| async move { self.services .timeline .get_pdu(&event_id) .await - .map(Arc::new) - .map(move |pdu| (tsk, pdu)) + .map(move |pdu| (((*ty).clone(), (*sk).clone()), pdu)) .ok() }) .collect() - .await; - - Ok(auth_pdus) + .map(Ok) + .await } } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index 0f5520bb..98aac138 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,6 +1,7 @@ use std::{ borrow::Borrow, fmt::Write, + ops::Deref, sync::{Arc, Mutex as StdMutex, Mutex}, }; @@ -10,8 +11,7 @@ use conduwuit::{ utils, utils::{ math::{usize_from_f64, Expected}, - stream::BroadbandExt, - IterStream, ReadyExt, + stream::{BroadbandExt, IterStream, ReadyExt, TryExpect}, }, Err, Error, PduEvent, Result, }; @@ -158,12 +158,8 @@ impl Service { ) -> impl Stream + Send + '_ { let short_ids = self .state_full_shortids(shortstatehash) - .map(|result| result.expect("missing shortstatehash")) - .map(Vec::into_iter) - .map(|iter| iter.map(at!(1))) - .map(IterStream::stream) - .flatten_stream() - .boxed(); + .expect_ok() + .map(at!(1)); self.services .short @@ -187,9 +183,8 @@ impl Service { { let shortids = self .state_full_shortids(shortstatehash) - .map(|result| result.expect("missing shortstatehash")) - .map(|vec| vec.into_iter().unzip()) - .boxed() + .expect_ok() + .unzip() .shared(); let shortstatekeys = shortids @@ -255,25 +250,25 @@ impl Service { } #[inline] - pub async fn state_full_shortids( + pub fn state_full_shortids( &self, shortstatehash: ShortStateHash, - ) -> Result> { - let shortids = self - .services + ) -> impl Stream> + Send + '_ { + self.services .state_compressor .load_shortstatehash_info(shortstatehash) - .await - .map_err(|e| err!(Database("Missing state IDs: {e}")))? - .pop() - .expect("there is always one layer") - .full_state - .iter() - .copied() - .map(parse_compressed_state_event) - .collect(); - - Ok(shortids) + .map_err(|e| err!(Database("Missing state IDs: {e}"))) + .map_ok(|vec| vec.last().expect("at least one layer").full_state.clone()) + .map_ok(|full_state| { + full_state + .deref() + .iter() + .copied() + .map(parse_compressed_state_event) + .collect() + }) + .map_ok(|vec: Vec<_>| vec.into_iter().try_stream()) + .try_flatten_stream() } /// Returns a single PDU from `room_id` with key (`event_type`,