diff --git a/src/service/rooms/state_accessor/state.rs b/src/service/rooms/state_accessor/state.rs index c47a5693..3cf168c1 100644 --- a/src/service/rooms/state_accessor/state.rs +++ b/src/service/rooms/state_accessor/state.rs @@ -9,7 +9,7 @@ use conduwuit::{ PduEvent, Result, }; use database::Deserialized; -use futures::{future::try_join, FutureExt, Stream, StreamExt, TryFutureExt}; +use futures::{future::try_join, pin_mut, FutureExt, Stream, StreamExt, TryFutureExt}; use ruma::{ events::{ room::member::{MembershipState, RoomMemberEventContent}, @@ -69,7 +69,6 @@ where } #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] pub async fn state_contains( &self, shortstatehash: ShortStateHash, @@ -90,7 +89,18 @@ pub async fn state_contains( } #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] +pub async fn state_contains_type( + &self, + shortstatehash: ShortStateHash, + event_type: &StateEventType, +) -> bool { + let state_keys = self.state_keys(shortstatehash, event_type); + + pin_mut!(state_keys); + state_keys.next().await.is_some() +} + +#[implement(super::Service)] pub async fn state_contains_shortstatekey( &self, shortstatehash: ShortStateHash, @@ -125,7 +135,6 @@ pub async fn state_get( /// Returns a single EventId from `room_id` with key (`event_type`, /// `state_key`). #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] pub async fn state_get_id( &self, shortstatehash: ShortStateHash, @@ -149,7 +158,6 @@ where /// Returns a single EventId from `room_id` with key (`event_type`, /// `state_key`). #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] pub async fn state_get_shortid( &self, shortstatehash: ShortStateHash, @@ -177,6 +185,103 @@ pub async fn state_get_shortid( .await? } +/// Iterates the state_keys for an event_type in the state; current state +/// event_id included. +#[implement(super::Service)] +pub fn state_keys_with_ids<'a, Id>( + &'a self, + shortstatehash: ShortStateHash, + event_type: &'a StateEventType, +) -> impl Stream + Send + 'a +where + Id: for<'de> Deserialize<'de> + Send + Sized + ToOwned + 'a, + ::Owned: Borrow, +{ + let state_keys_with_short_ids = self + .state_keys_with_shortids(shortstatehash, event_type) + .unzip() + .map(|(ssks, sids): (Vec, Vec)| (ssks, sids)) + .shared(); + + let state_keys = state_keys_with_short_ids + .clone() + .map(at!(0)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + let shorteventids = state_keys_with_short_ids + .map(at!(1)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + self.services + .short + .multi_get_eventid_from_short(shorteventids) + .zip(state_keys) + .ready_filter_map(|(eid, sk)| eid.map(move |eid| (sk, eid)).ok()) +} + +/// Iterates the state_keys for an event_type in the state; current state +/// event_id included. +#[implement(super::Service)] +pub fn state_keys_with_shortids<'a>( + &'a self, + shortstatehash: ShortStateHash, + event_type: &'a StateEventType, +) -> impl Stream + Send + 'a { + let short_ids = self + .state_full_shortids(shortstatehash) + .expect_ok() + .unzip() + .map(|(ssks, sids): (Vec, Vec)| (ssks, sids)) + .shared(); + + let shortstatekeys = short_ids + .clone() + .map(at!(0)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + let shorteventids = short_ids + .map(at!(1)) + .map(Vec::into_iter) + .map(IterStream::stream) + .flatten_stream(); + + self.services + .short + .multi_get_statekey_from_short(shortstatekeys) + .zip(shorteventids) + .ready_filter_map(|(res, id)| res.map(|res| (res, id)).ok()) + .ready_filter_map(move |((event_type_, state_key), event_id)| { + event_type_.eq(event_type).then_some((state_key, event_id)) + }) +} + +/// Iterates the state_keys for an event_type in the state +#[implement(super::Service)] +pub fn state_keys<'a>( + &'a self, + shortstatehash: ShortStateHash, + event_type: &'a StateEventType, +) -> impl Stream + Send + 'a { + let short_ids = self + .state_full_shortids(shortstatehash) + .expect_ok() + .map(at!(0)); + + self.services + .short + .multi_get_statekey_from_short(short_ids) + .ready_filter_map(Result::ok) + .ready_filter_map(move |(event_type_, state_key)| { + event_type_.eq(event_type).then_some(state_key) + }) +} + /// Returns the state events removed between the interval (present in .0 but /// not in .1) #[implement(super::Service)] @@ -191,11 +296,10 @@ pub fn state_removed( /// Returns the state events added between the interval (present in .1 but /// not in .0) #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] -pub fn state_added<'a>( - &'a self, +pub fn state_added( + &self, shortstatehash: pair_of!(ShortStateHash), -) -> impl Stream + Send + 'a { +) -> impl Stream + Send + '_ { let a = self.load_full_state(shortstatehash.0); let b = self.load_full_state(shortstatehash.1); try_join(a, b) @@ -239,7 +343,6 @@ pub fn state_full_pdus( /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[implement(super::Service)] -#[tracing::instrument(skip(self), level = "debug")] pub fn state_full_ids<'a, Id>( &'a self, shortstatehash: ShortStateHash, @@ -293,6 +396,7 @@ pub fn state_full_shortids( } #[implement(super::Service)] +#[tracing::instrument(name = "load", level = "debug", skip(self))] async fn load_full_state(&self, shortstatehash: ShortStateHash) -> Result> { self.services .state_compressor