parallelize state-res pre-gathering

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-29 23:07:12 +00:00
parent 50acfe7832
commit 3c8376d897
2 changed files with 123 additions and 113 deletions

View file

@ -5,11 +5,11 @@ use std::{
}; };
use conduwuit::{ use conduwuit::{
debug, err, implement, err, implement, trace,
utils::stream::{automatic_width, IterStream, ReadyExt, TryWidebandExt, WidebandExt}, utils::stream::{automatic_width, IterStream, ReadyExt, TryWidebandExt, WidebandExt},
Result, Error, Result,
}; };
use futures::{FutureExt, StreamExt, TryStreamExt}; use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{ use ruma::{
state_res::{self, StateMap}, state_res::{self, StateMap},
OwnedEventId, RoomId, RoomVersionId, OwnedEventId, RoomId, RoomVersionId,
@ -25,13 +25,13 @@ pub async fn resolve_state(
room_version_id: &RoomVersionId, room_version_id: &RoomVersionId,
incoming_state: HashMap<u64, OwnedEventId>, incoming_state: HashMap<u64, OwnedEventId>,
) -> Result<Arc<HashSet<CompressedStateEvent>>> { ) -> Result<Arc<HashSet<CompressedStateEvent>>> {
debug!("Loading current room state ids"); trace!("Loading current room state ids");
let current_sstatehash = self let current_sstatehash = self
.services .services
.state .state
.get_room_shortstatehash(room_id) .get_room_shortstatehash(room_id)
.await .map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))
.map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?; .await?;
let current_state_ids: HashMap<_, _> = self let current_state_ids: HashMap<_, _> = self
.services .services
@ -40,8 +40,9 @@ pub async fn resolve_state(
.collect() .collect()
.await; .await;
trace!("Loading fork states");
let fork_states = [current_state_ids, incoming_state]; let fork_states = [current_state_ids, incoming_state];
let auth_chain_sets: Vec<HashSet<OwnedEventId>> = fork_states let auth_chain_sets = fork_states
.iter() .iter()
.try_stream() .try_stream()
.wide_and_then(|state| { .wide_and_then(|state| {
@ -50,36 +51,33 @@ pub async fn resolve_state(
.event_ids_iter(room_id, state.values().map(Borrow::borrow)) .event_ids_iter(room_id, state.values().map(Borrow::borrow))
.try_collect() .try_collect()
}) })
.try_collect() .try_collect::<Vec<HashSet<OwnedEventId>>>();
.await?;
debug!("Loading fork states"); let fork_states = fork_states
let fork_states: Vec<StateMap<OwnedEventId>> = fork_states .iter()
.into_iter()
.stream() .stream()
.wide_then(|fork_state| async move { .wide_then(|fork_state| {
let shortstatekeys = fork_state.keys().copied().stream(); let shortstatekeys = fork_state.keys().copied().stream();
let event_ids = fork_state.values().cloned().stream();
let event_ids = fork_state.values().cloned().stream().boxed();
self.services self.services
.short .short
.multi_get_statekey_from_short(shortstatekeys) .multi_get_statekey_from_short(shortstatekeys)
.zip(event_ids) .zip(event_ids)
.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id))) .ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
.collect() .collect()
.await
}) })
.collect() .map(Ok::<_, Error>)
.await; .try_collect::<Vec<StateMap<OwnedEventId>>>();
debug!("Resolving state"); let (fork_states, auth_chain_sets) = try_join(fork_states, auth_chain_sets).await?;
trace!("Resolving state");
let state = self let state = self
.state_resolution(room_version_id, &fork_states, &auth_chain_sets) .state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.boxed() .boxed()
.await?; .await?;
debug!("State resolution done."); trace!("State resolution done.");
let state_events: Vec<_> = state let state_events: Vec<_> = state
.iter() .iter()
.stream() .stream()
@ -92,7 +90,7 @@ pub async fn resolve_state(
.collect() .collect()
.await; .await;
debug!("Compressing state..."); trace!("Compressing state...");
let new_room_state: HashSet<_> = self let new_room_state: HashSet<_> = self
.services .services
.state_compressor .state_compressor
@ -109,20 +107,23 @@ pub async fn resolve_state(
#[implement(super::Service)] #[implement(super::Service)]
#[tracing::instrument(name = "ruma", level = "debug", skip_all)] #[tracing::instrument(name = "ruma", level = "debug", skip_all)]
pub async fn state_resolution( pub async fn state_resolution<'a, StateSets>(
&self, &'a self,
room_version: &RoomVersionId, room_version: &'a RoomVersionId,
state_sets: &[StateMap<OwnedEventId>], state_sets: StateSets,
auth_chain_sets: &[HashSet<OwnedEventId>], auth_chain_sets: &'a [HashSet<OwnedEventId>],
) -> Result<StateMap<OwnedEventId>> { ) -> Result<StateMap<OwnedEventId>>
where
StateSets: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
{
state_res::resolve( state_res::resolve(
room_version, room_version,
state_sets.iter(), state_sets,
auth_chain_sets, auth_chain_sets,
&|event_id| self.event_fetch(event_id), &|event_id| self.event_fetch(event_id),
&|event_id| self.event_exists(event_id), &|event_id| self.event_exists(event_id),
automatic_width(), automatic_width(),
) )
.await
.map_err(|e| err!(error!("State resolution failed: {e:?}"))) .map_err(|e| err!(error!("State resolution failed: {e:?}")))
.await
} }

View file

@ -1,18 +1,20 @@
use std::{ use std::{
borrow::Borrow, borrow::Borrow,
collections::{HashMap, HashSet}, collections::{HashMap, HashSet},
iter::Iterator,
sync::Arc, sync::Arc,
}; };
use conduwuit::{ use conduwuit::{
debug, err, implement, debug, err, implement, trace,
result::LogErr, utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
utils::stream::{BroadbandExt, IterStream},
PduEvent, Result, PduEvent, Result,
}; };
use futures::{FutureExt, StreamExt, TryStreamExt}; use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId}; use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId};
use crate::rooms::short::ShortStateHash;
// TODO: if we know the prev_events of the incoming event we can avoid the // TODO: if we know the prev_events of the incoming event we can avoid the
#[implement(super::Service)] #[implement(super::Service)]
// request and build the state from a known point and resolve if > 1 prev_event // request and build the state from a known point and resolve if > 1 prev_event
@ -70,36 +72,73 @@ pub(super) async fn state_at_incoming_resolved(
room_id: &RoomId, room_id: &RoomId,
room_version_id: &RoomVersionId, room_version_id: &RoomVersionId,
) -> Result<Option<HashMap<u64, OwnedEventId>>> { ) -> Result<Option<HashMap<u64, OwnedEventId>>> {
debug!("Calculating state at event using state res"); trace!("Calculating extremity statehashes...");
let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len()); let Ok(extremity_sstatehashes) = incoming_pdu
.prev_events
let mut okay = true; .iter()
for prev_eventid in &incoming_pdu.prev_events { .try_stream()
let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else { .broad_and_then(|prev_eventid| {
okay = false; self.services
break; .timeline
}; .get_pdu(prev_eventid)
.map_ok(move |prev_event| (prev_eventid, prev_event))
let Ok(sstatehash) = self })
.services .broad_and_then(|(prev_eventid, prev_event)| {
self.services
.state_accessor .state_accessor
.pdu_shortstatehash(prev_eventid) .pdu_shortstatehash(prev_eventid)
.map_ok(move |sstatehash| (sstatehash, prev_event))
})
.try_collect::<HashMap<_, _>>()
.await .await
else { else {
okay = false; return Ok(None);
break;
}; };
extremity_sstatehashes.insert(sstatehash, prev_event); trace!("Calculating fork states...");
} let (fork_states, auth_chain_sets): (Vec<StateMap<_>>, Vec<HashSet<_>>) =
extremity_sstatehashes
.into_iter()
.try_stream()
.wide_and_then(|(sstatehash, prev_event)| {
self.state_at_incoming_fork(room_id, sstatehash, prev_event)
})
.try_collect()
.map_ok(Vec::into_iter)
.map_ok(Iterator::unzip)
.await?;
if !okay { let Ok(new_state) = self
.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
.boxed()
.await
else {
return Ok(None); return Ok(None);
} };
let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); new_state
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); .into_iter()
for (sstatehash, prev_event) in extremity_sstatehashes { .stream()
.broad_then(|((event_type, state_key), event_id)| async move {
self.services
.short
.get_or_create_shortstatekey(&event_type, &state_key)
.map(move |shortstatekey| (shortstatekey, event_id))
.await
})
.collect()
.map(Some)
.map(Ok)
.await
}
#[implement(super::Service)]
async fn state_at_incoming_fork(
&self,
room_id: &RoomId,
sstatehash: ShortStateHash,
prev_event: PduEvent,
) -> Result<(StateMap<OwnedEventId>, HashSet<OwnedEventId>)> {
let mut leaf_state: HashMap<_, _> = self let mut leaf_state: HashMap<_, _> = self
.services .services
.state_accessor .state_accessor
@ -119,54 +158,24 @@ pub(super) async fn state_at_incoming_resolved(
// Now it's the state after the pdu // Now it's the state after the pdu
} }
let mut state = StateMap::with_capacity(leaf_state.len()); let auth_chain = self
let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in &leaf_state {
if let Ok((ty, st_key)) = self
.services
.short
.get_statekey_from_short(*k)
.await
.log_err()
{
// FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone());
}
starting_events.push(id.borrow());
}
let auth_chain: HashSet<OwnedEventId> = self
.services .services
.auth_chain .auth_chain
.event_ids_iter(room_id, starting_events.into_iter()) .event_ids_iter(room_id, leaf_state.values().map(Borrow::borrow))
.try_collect() .try_collect();
.await?;
auth_chain_sets.push(auth_chain); let fork_state = leaf_state
fork_states.push(state);
}
let Ok(new_state) = self
.state_resolution(room_version_id, &fork_states, &auth_chain_sets)
.boxed()
.await
else {
return Ok(None);
};
new_state
.iter() .iter()
.stream() .stream()
.broad_then(|((event_type, state_key), event_id)| { .broad_then(|(k, id)| {
self.services self.services
.short .short
.get_or_create_shortstatekey(event_type, state_key) .get_statekey_from_short(*k)
.map(move |shortstatekey| (shortstatekey, event_id.clone())) .map_ok(|(ty, sk)| ((ty, sk), id.clone()))
}) })
.ready_filter_map(Result::ok)
.collect() .collect()
.map(Some) .map(Ok);
.map(Ok)
.await try_join(fork_state, auth_chain).await
} }