diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 03f7e822..c3de5f2f 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -5,11 +5,11 @@ use std::{ }; use conduwuit::{ - debug, err, implement, + err, implement, trace, utils::stream::{automatic_width, IterStream, ReadyExt, TryWidebandExt, WidebandExt}, - Result, + Error, Result, }; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{ state_res::{self, StateMap}, OwnedEventId, RoomId, RoomVersionId, @@ -25,13 +25,13 @@ pub async fn resolve_state( room_version_id: &RoomVersionId, incoming_state: HashMap, ) -> Result>> { - debug!("Loading current room state ids"); + trace!("Loading current room state ids"); let current_sstatehash = self .services .state .get_room_shortstatehash(room_id) - .await - .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; + .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}")))) + .await?; let current_state_ids: HashMap<_, _> = self .services @@ -40,8 +40,9 @@ pub async fn resolve_state( .collect() .await; + trace!("Loading fork states"); let fork_states = [current_state_ids, incoming_state]; - let auth_chain_sets: Vec> = fork_states + let auth_chain_sets = fork_states .iter() .try_stream() .wide_and_then(|state| { @@ -50,36 +51,33 @@ pub async fn resolve_state( .event_ids_iter(room_id, state.values().map(Borrow::borrow)) .try_collect() }) - .try_collect() - .await?; + .try_collect::>>(); - debug!("Loading fork states"); - let fork_states: Vec> = fork_states - .into_iter() + let fork_states = fork_states + .iter() .stream() - .wide_then(|fork_state| async move { + .wide_then(|fork_state| { let shortstatekeys = fork_state.keys().copied().stream(); - - let event_ids = fork_state.values().cloned().stream().boxed(); - + let event_ids = fork_state.values().cloned().stream(); self.services .short .multi_get_statekey_from_short(shortstatekeys) .zip(event_ids) .ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id))) .collect() - .await }) - .collect() - .await; + .map(Ok::<_, Error>) + .try_collect::>>(); - debug!("Resolving state"); + let (fork_states, auth_chain_sets) = try_join(fork_states, auth_chain_sets).await?; + + trace!("Resolving state"); let state = self - .state_resolution(room_version_id, &fork_states, &auth_chain_sets) + .state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets) .boxed() .await?; - debug!("State resolution done."); + trace!("State resolution done."); let state_events: Vec<_> = state .iter() .stream() @@ -92,7 +90,7 @@ pub async fn resolve_state( .collect() .await; - debug!("Compressing state..."); + trace!("Compressing state..."); let new_room_state: HashSet<_> = self .services .state_compressor @@ -109,20 +107,23 @@ pub async fn resolve_state( #[implement(super::Service)] #[tracing::instrument(name = "ruma", level = "debug", skip_all)] -pub async fn state_resolution( - &self, - room_version: &RoomVersionId, - state_sets: &[StateMap], - auth_chain_sets: &[HashSet], -) -> Result> { +pub async fn state_resolution<'a, StateSets>( + &'a self, + room_version: &'a RoomVersionId, + state_sets: StateSets, + auth_chain_sets: &'a [HashSet], +) -> Result> +where + StateSets: Iterator> + Clone + Send, +{ state_res::resolve( room_version, - state_sets.iter(), + state_sets, auth_chain_sets, &|event_id| self.event_fetch(event_id), &|event_id| self.event_exists(event_id), automatic_width(), ) - .await .map_err(|e| err!(error!("State resolution failed: {e:?}"))) + .await } diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 8730232a..8ae6354c 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -1,18 +1,20 @@ use std::{ borrow::Borrow, collections::{HashMap, HashSet}, + iter::Iterator, sync::Arc, }; use conduwuit::{ - debug, err, implement, - result::LogErr, - utils::stream::{BroadbandExt, IterStream}, + debug, err, implement, trace, + utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt}, PduEvent, Result, }; -use futures::{FutureExt, StreamExt, TryStreamExt}; +use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId}; +use crate::rooms::short::ShortStateHash; + // TODO: if we know the prev_events of the incoming event we can avoid the #[implement(super::Service)] // request and build the state from a known point and resolve if > 1 prev_event @@ -70,86 +72,44 @@ pub(super) async fn state_at_incoming_resolved( room_id: &RoomId, room_version_id: &RoomVersionId, ) -> Result>> { - debug!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { - okay = false; - break; - }; - - let Ok(sstatehash) = self - .services - .state_accessor - .pdu_shortstatehash(prev_eventid) - .await - else { - okay = false; - break; - }; - - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if !okay { + trace!("Calculating extremity statehashes..."); + let Ok(extremity_sstatehashes) = incoming_pdu + .prev_events + .iter() + .try_stream() + .broad_and_then(|prev_eventid| { + self.services + .timeline + .get_pdu(prev_eventid) + .map_ok(move |prev_event| (prev_eventid, prev_event)) + }) + .broad_and_then(|(prev_eventid, prev_event)| { + self.services + .state_accessor + .pdu_shortstatehash(prev_eventid) + .map_ok(move |sstatehash| (sstatehash, prev_event)) + }) + .try_collect::>() + .await + else { return Ok(None); - } + }; - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - for (sstatehash, prev_event) in extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = self - .services - .state_accessor - .state_full_ids(sstatehash) - .collect() - .await; - - if let Some(state_key) = &prev_event.state_key { - let shortstatekey = self - .services - .short - .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) - .await; - - let event_id = &prev_event.event_id; - leaf_state.insert(shortstatekey, event_id.clone()); - // Now it's the state after the pdu - } - - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); - for (k, id) in &leaf_state { - if let Ok((ty, st_key)) = self - .services - .short - .get_statekey_from_short(*k) - .await - .log_err() - { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } - - starting_events.push(id.borrow()); - } - - let auth_chain: HashSet = self - .services - .auth_chain - .event_ids_iter(room_id, starting_events.into_iter()) + trace!("Calculating fork states..."); + let (fork_states, auth_chain_sets): (Vec>, Vec>) = + extremity_sstatehashes + .into_iter() + .try_stream() + .wide_and_then(|(sstatehash, prev_event)| { + self.state_at_incoming_fork(room_id, sstatehash, prev_event) + }) .try_collect() + .map_ok(Vec::into_iter) + .map_ok(Iterator::unzip) .await?; - auth_chain_sets.push(auth_chain); - fork_states.push(state); - } - let Ok(new_state) = self - .state_resolution(room_version_id, &fork_states, &auth_chain_sets) + .state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets) .boxed() .await else { @@ -157,16 +117,65 @@ pub(super) async fn state_at_incoming_resolved( }; new_state - .iter() + .into_iter() .stream() - .broad_then(|((event_type, state_key), event_id)| { + .broad_then(|((event_type, state_key), event_id)| async move { self.services .short - .get_or_create_shortstatekey(event_type, state_key) - .map(move |shortstatekey| (shortstatekey, event_id.clone())) + .get_or_create_shortstatekey(&event_type, &state_key) + .map(move |shortstatekey| (shortstatekey, event_id)) + .await }) .collect() .map(Some) .map(Ok) .await } + +#[implement(super::Service)] +async fn state_at_incoming_fork( + &self, + room_id: &RoomId, + sstatehash: ShortStateHash, + prev_event: PduEvent, +) -> Result<(StateMap, HashSet)> { + let mut leaf_state: HashMap<_, _> = self + .services + .state_accessor + .state_full_ids(sstatehash) + .collect() + .await; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = self + .services + .short + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key) + .await; + + let event_id = &prev_event.event_id; + leaf_state.insert(shortstatekey, event_id.clone()); + // Now it's the state after the pdu + } + + let auth_chain = self + .services + .auth_chain + .event_ids_iter(room_id, leaf_state.values().map(Borrow::borrow)) + .try_collect(); + + let fork_state = leaf_state + .iter() + .stream() + .broad_then(|(k, id)| { + self.services + .short + .get_statekey_from_short(*k) + .map_ok(|(ty, sk)| ((ty, sk), id.clone())) + }) + .ready_filter_map(Result::ok) + .collect() + .map(Ok); + + try_join(fork_state, auth_chain).await +}