From 3c8376d897e6a1b9b6b61f5ada05b2afec1ab937 Mon Sep 17 00:00:00 2001
From: Jason Volk <jason@zemos.net>
Date: Wed, 29 Jan 2025 23:07:12 +0000
Subject: [PATCH] parallelize state-res pre-gathering

Signed-off-by: Jason Volk <jason@zemos.net>
---
 .../rooms/event_handler/resolve_state.rs      |  63 +++----
 .../rooms/event_handler/state_at_incoming.rs  | 173 +++++++++---------
 2 files changed, 123 insertions(+), 113 deletions(-)

diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs
index 03f7e822..c3de5f2f 100644
--- a/src/service/rooms/event_handler/resolve_state.rs
+++ b/src/service/rooms/event_handler/resolve_state.rs
@@ -5,11 +5,11 @@ use std::{
 };
 
 use conduwuit::{
-	debug, err, implement,
+	err, implement, trace,
 	utils::stream::{automatic_width, IterStream, ReadyExt, TryWidebandExt, WidebandExt},
-	Result,
+	Error, Result,
 };
-use futures::{FutureExt, StreamExt, TryStreamExt};
+use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
 use ruma::{
 	state_res::{self, StateMap},
 	OwnedEventId, RoomId, RoomVersionId,
@@ -25,13 +25,13 @@ pub async fn resolve_state(
 	room_version_id: &RoomVersionId,
 	incoming_state: HashMap<u64, OwnedEventId>,
 ) -> Result<Arc<HashSet<CompressedStateEvent>>> {
-	debug!("Loading current room state ids");
+	trace!("Loading current room state ids");
 	let current_sstatehash = self
 		.services
 		.state
 		.get_room_shortstatehash(room_id)
-		.await
-		.map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))?;
+		.map_err(|e| err!(Database(error!("No state for {room_id:?}: {e:?}"))))
+		.await?;
 
 	let current_state_ids: HashMap<_, _> = self
 		.services
@@ -40,8 +40,9 @@ pub async fn resolve_state(
 		.collect()
 		.await;
 
+	trace!("Loading fork states");
 	let fork_states = [current_state_ids, incoming_state];
-	let auth_chain_sets: Vec<HashSet<OwnedEventId>> = fork_states
+	let auth_chain_sets = fork_states
 		.iter()
 		.try_stream()
 		.wide_and_then(|state| {
@@ -50,36 +51,33 @@ pub async fn resolve_state(
 				.event_ids_iter(room_id, state.values().map(Borrow::borrow))
 				.try_collect()
 		})
-		.try_collect()
-		.await?;
+		.try_collect::<Vec<HashSet<OwnedEventId>>>();
 
-	debug!("Loading fork states");
-	let fork_states: Vec<StateMap<OwnedEventId>> = fork_states
-		.into_iter()
+	let fork_states = fork_states
+		.iter()
 		.stream()
-		.wide_then(|fork_state| async move {
+		.wide_then(|fork_state| {
 			let shortstatekeys = fork_state.keys().copied().stream();
-
-			let event_ids = fork_state.values().cloned().stream().boxed();
-
+			let event_ids = fork_state.values().cloned().stream();
 			self.services
 				.short
 				.multi_get_statekey_from_short(shortstatekeys)
 				.zip(event_ids)
 				.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
 				.collect()
-				.await
 		})
-		.collect()
-		.await;
+		.map(Ok::<_, Error>)
+		.try_collect::<Vec<StateMap<OwnedEventId>>>();
 
-	debug!("Resolving state");
+	let (fork_states, auth_chain_sets) = try_join(fork_states, auth_chain_sets).await?;
+
+	trace!("Resolving state");
 	let state = self
-		.state_resolution(room_version_id, &fork_states, &auth_chain_sets)
+		.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
 		.boxed()
 		.await?;
 
-	debug!("State resolution done.");
+	trace!("State resolution done.");
 	let state_events: Vec<_> = state
 		.iter()
 		.stream()
@@ -92,7 +90,7 @@ pub async fn resolve_state(
 		.collect()
 		.await;
 
-	debug!("Compressing state...");
+	trace!("Compressing state...");
 	let new_room_state: HashSet<_> = self
 		.services
 		.state_compressor
@@ -109,20 +107,23 @@ pub async fn resolve_state(
 
 #[implement(super::Service)]
 #[tracing::instrument(name = "ruma", level = "debug", skip_all)]
-pub async fn state_resolution(
-	&self,
-	room_version: &RoomVersionId,
-	state_sets: &[StateMap<OwnedEventId>],
-	auth_chain_sets: &[HashSet<OwnedEventId>],
-) -> Result<StateMap<OwnedEventId>> {
+pub async fn state_resolution<'a, StateSets>(
+	&'a self,
+	room_version: &'a RoomVersionId,
+	state_sets: StateSets,
+	auth_chain_sets: &'a [HashSet<OwnedEventId>],
+) -> Result<StateMap<OwnedEventId>>
+where
+	StateSets: Iterator<Item = &'a StateMap<OwnedEventId>> + Clone + Send,
+{
 	state_res::resolve(
 		room_version,
-		state_sets.iter(),
+		state_sets,
 		auth_chain_sets,
 		&|event_id| self.event_fetch(event_id),
 		&|event_id| self.event_exists(event_id),
 		automatic_width(),
 	)
-	.await
 	.map_err(|e| err!(error!("State resolution failed: {e:?}")))
+	.await
 }
diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs
index 8730232a..8ae6354c 100644
--- a/src/service/rooms/event_handler/state_at_incoming.rs
+++ b/src/service/rooms/event_handler/state_at_incoming.rs
@@ -1,18 +1,20 @@
 use std::{
 	borrow::Borrow,
 	collections::{HashMap, HashSet},
+	iter::Iterator,
 	sync::Arc,
 };
 
 use conduwuit::{
-	debug, err, implement,
-	result::LogErr,
-	utils::stream::{BroadbandExt, IterStream},
+	debug, err, implement, trace,
+	utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt},
 	PduEvent, Result,
 };
-use futures::{FutureExt, StreamExt, TryStreamExt};
+use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
 use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId};
 
+use crate::rooms::short::ShortStateHash;
+
 // TODO: if we know the prev_events of the incoming event we can avoid the
 #[implement(super::Service)]
 // request and build the state from a known point and resolve if > 1 prev_event
@@ -70,86 +72,44 @@ pub(super) async fn state_at_incoming_resolved(
 	room_id: &RoomId,
 	room_version_id: &RoomVersionId,
 ) -> Result<Option<HashMap<u64, OwnedEventId>>> {
-	debug!("Calculating state at event using state res");
-	let mut extremity_sstatehashes = HashMap::with_capacity(incoming_pdu.prev_events.len());
-
-	let mut okay = true;
-	for prev_eventid in &incoming_pdu.prev_events {
-		let Ok(prev_event) = self.services.timeline.get_pdu(prev_eventid).await else {
-			okay = false;
-			break;
-		};
-
-		let Ok(sstatehash) = self
-			.services
-			.state_accessor
-			.pdu_shortstatehash(prev_eventid)
-			.await
-		else {
-			okay = false;
-			break;
-		};
-
-		extremity_sstatehashes.insert(sstatehash, prev_event);
-	}
-
-	if !okay {
+	trace!("Calculating extremity statehashes...");
+	let Ok(extremity_sstatehashes) = incoming_pdu
+		.prev_events
+		.iter()
+		.try_stream()
+		.broad_and_then(|prev_eventid| {
+			self.services
+				.timeline
+				.get_pdu(prev_eventid)
+				.map_ok(move |prev_event| (prev_eventid, prev_event))
+		})
+		.broad_and_then(|(prev_eventid, prev_event)| {
+			self.services
+				.state_accessor
+				.pdu_shortstatehash(prev_eventid)
+				.map_ok(move |sstatehash| (sstatehash, prev_event))
+		})
+		.try_collect::<HashMap<_, _>>()
+		.await
+	else {
 		return Ok(None);
-	}
+	};
 
-	let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len());
-	let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
-	for (sstatehash, prev_event) in extremity_sstatehashes {
-		let mut leaf_state: HashMap<_, _> = self
-			.services
-			.state_accessor
-			.state_full_ids(sstatehash)
-			.collect()
-			.await;
-
-		if let Some(state_key) = &prev_event.state_key {
-			let shortstatekey = self
-				.services
-				.short
-				.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)
-				.await;
-
-			let event_id = &prev_event.event_id;
-			leaf_state.insert(shortstatekey, event_id.clone());
-			// Now it's the state after the pdu
-		}
-
-		let mut state = StateMap::with_capacity(leaf_state.len());
-		let mut starting_events = Vec::with_capacity(leaf_state.len());
-		for (k, id) in &leaf_state {
-			if let Ok((ty, st_key)) = self
-				.services
-				.short
-				.get_statekey_from_short(*k)
-				.await
-				.log_err()
-			{
-				// FIXME: Undo .to_string().into() when StateMap
-				//        is updated to use StateEventType
-				state.insert((ty.to_string().into(), st_key), id.clone());
-			}
-
-			starting_events.push(id.borrow());
-		}
-
-		let auth_chain: HashSet<OwnedEventId> = self
-			.services
-			.auth_chain
-			.event_ids_iter(room_id, starting_events.into_iter())
+	trace!("Calculating fork states...");
+	let (fork_states, auth_chain_sets): (Vec<StateMap<_>>, Vec<HashSet<_>>) =
+		extremity_sstatehashes
+			.into_iter()
+			.try_stream()
+			.wide_and_then(|(sstatehash, prev_event)| {
+				self.state_at_incoming_fork(room_id, sstatehash, prev_event)
+			})
 			.try_collect()
+			.map_ok(Vec::into_iter)
+			.map_ok(Iterator::unzip)
 			.await?;
 
-		auth_chain_sets.push(auth_chain);
-		fork_states.push(state);
-	}
-
 	let Ok(new_state) = self
-		.state_resolution(room_version_id, &fork_states, &auth_chain_sets)
+		.state_resolution(room_version_id, fork_states.iter(), &auth_chain_sets)
 		.boxed()
 		.await
 	else {
@@ -157,16 +117,65 @@ pub(super) async fn state_at_incoming_resolved(
 	};
 
 	new_state
-		.iter()
+		.into_iter()
 		.stream()
-		.broad_then(|((event_type, state_key), event_id)| {
+		.broad_then(|((event_type, state_key), event_id)| async move {
 			self.services
 				.short
-				.get_or_create_shortstatekey(event_type, state_key)
-				.map(move |shortstatekey| (shortstatekey, event_id.clone()))
+				.get_or_create_shortstatekey(&event_type, &state_key)
+				.map(move |shortstatekey| (shortstatekey, event_id))
+				.await
 		})
 		.collect()
 		.map(Some)
 		.map(Ok)
 		.await
 }
+
+#[implement(super::Service)]
+async fn state_at_incoming_fork(
+	&self,
+	room_id: &RoomId,
+	sstatehash: ShortStateHash,
+	prev_event: PduEvent,
+) -> Result<(StateMap<OwnedEventId>, HashSet<OwnedEventId>)> {
+	let mut leaf_state: HashMap<_, _> = self
+		.services
+		.state_accessor
+		.state_full_ids(sstatehash)
+		.collect()
+		.await;
+
+	if let Some(state_key) = &prev_event.state_key {
+		let shortstatekey = self
+			.services
+			.short
+			.get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)
+			.await;
+
+		let event_id = &prev_event.event_id;
+		leaf_state.insert(shortstatekey, event_id.clone());
+		// Now it's the state after the pdu
+	}
+
+	let auth_chain = self
+		.services
+		.auth_chain
+		.event_ids_iter(room_id, leaf_state.values().map(Borrow::borrow))
+		.try_collect();
+
+	let fork_state = leaf_state
+		.iter()
+		.stream()
+		.broad_then(|(k, id)| {
+			self.services
+				.short
+				.get_statekey_from_short(*k)
+				.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
+		})
+		.ready_filter_map(Result::ok)
+		.collect()
+		.map(Ok);
+
+	try_join(fork_state, auth_chain).await
+}