From fd33f9aa79c842ef6ba21b03738f3d148121a331 Mon Sep 17 00:00:00 2001
From: Jason Volk
Date: Sun, 6 Apr 2025 06:39:45 +0000
Subject: [PATCH] modernize state_res w/ stream extensions

Signed-off-by: Jason Volk
---
 src/core/matrix/state_res/mod.rs | 220 ++++++++++++++++---------------
 1 file changed, 111 insertions(+), 109 deletions(-)

diff --git a/src/core/matrix/state_res/mod.rs b/src/core/matrix/state_res/mod.rs
index 93c00d15..ce6b7e89 100644
--- a/src/core/matrix/state_res/mod.rs
+++ b/src/core/matrix/state_res/mod.rs
@@ -15,11 +15,10 @@ use std::{
     borrow::Borrow,
     cmp::{Ordering, Reverse},
     collections::{BinaryHeap, HashMap, HashSet},
-    fmt::Debug,
     hash::{BuildHasher, Hash},
 };
 
-use futures::{Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt, future, stream};
+use futures::{Future, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt, future};
 use ruma::{
     EventId, Int, MilliSecondsSinceUnixEpoch, RoomVersionId,
     events::{
@@ -37,9 +36,13 @@ pub use self::{
     room_version::RoomVersion,
 };
 use crate::{
-    debug,
+    debug, debug_error,
     matrix::{event::Event, pdu::StateKey},
-    trace, warn,
+    trace,
+    utils::stream::{
+        BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryReadyExt, WidebandExt,
+    },
+    warn,
 };
 
 /// A mapping of event type and state_key to some value `T`, usually an
@@ -112,20 +115,16 @@ where
     debug!(count = conflicting.len(), "conflicting events");
     trace!(map = ?conflicting, "conflicting events");
 
-    let auth_chain_diff =
-        get_auth_chain_diff(auth_chain_sets).chain(conflicting.into_values().flatten());
+    let conflicting_values = conflicting.into_values().flatten().stream();
 
     // `all_conflicted` contains unique items
     // synapse says `full_set = {eid for eid in full_conflicted_set if eid in
     // event_map}`
-    let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff)
-        // Don't honor events we cannot "verify"
-        .map(|id| event_exists(id.clone()).map(move |exists| (id, exists)))
-        .buffer_unordered(parallel_fetches)
-        .filter_map(|(id, exists)| future::ready(exists.then_some(id)))
-        .collect()
-        .boxed()
-        .await;
+    let all_conflicted: HashSet<_> = get_auth_chain_diff(auth_chain_sets)
+        .chain(conflicting_values)
+        .broad_filter_map(async |id| event_exists(id.clone()).await.then_some(id))
+        .collect()
+        .await;
 
     debug!(count = all_conflicted.len(), "full conflicted set");
     trace!(set = ?all_conflicted, "full conflicted set");
 
@@ -135,12 +134,15 @@ where
 
     // Get only the control events with a state_key: "" or ban/kick event (sender !=
     // state_key)
-    let control_events: Vec<_> = stream::iter(all_conflicted.iter())
-        .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is)))
-        .buffer_unordered(parallel_fetches)
-        .filter_map(|(id, is)| future::ready(is.then_some(id.clone())))
+    let control_events: Vec<_> = all_conflicted
+        .iter()
+        .stream()
+        .wide_filter_map(async |id| {
+            is_power_event_id(id, &event_fetch)
+                .await
+                .then_some(id.clone())
+        })
         .collect()
-        .boxed()
         .await;
 
     // Sort the control events based on power_level/clock/event_id and
@@ -160,10 +162,9 @@ where
     // Sequentially auth check each control event.
     let resolved_control = iterative_auth_check(
         &room_version,
-        sorted_control_levels.iter(),
+        sorted_control_levels.iter().stream(),
         clean.clone(),
         &event_fetch,
-        parallel_fetches,
     )
     .await?;
 
@@ -172,36 +173,35 @@ where
 
     // At this point the control_events have been resolved, and we now have to
     // sort the remaining events using the mainline of the resolved power level.
-    let deduped_power_ev = sorted_control_levels.into_iter().collect::<HashSet<_>>();
+    let deduped_power_ev: HashSet<_> = sorted_control_levels.into_iter().collect();
 
     // This removes the control events that passed auth and more importantly those
     // that failed auth
-    let events_to_resolve = all_conflicted
+    let events_to_resolve: Vec<_> = all_conflicted
         .iter()
         .filter(|&id| !deduped_power_ev.contains(id.borrow()))
         .cloned()
-        .collect::<Vec<_>>();
+        .collect();
 
     debug!(count = events_to_resolve.len(), "events left to resolve");
     trace!(list = ?events_to_resolve, "events left to resolve");
 
     // This "epochs" power level event
-    let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new()));
+    let power_levels_ty_sk = (StateEventType::RoomPowerLevels, StateKey::new());
+    let power_event = resolved_control.get(&power_levels_ty_sk);
 
     debug!(event_id = ?power_event, "power event");
 
     let sorted_left_events =
-        mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch, parallel_fetches)
-            .await?;
+        mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch).await?;
 
     trace!(list = ?sorted_left_events, "events left, sorted");
 
     let mut resolved_state = iterative_auth_check(
         &room_version,
-        sorted_left_events.iter(),
+        sorted_left_events.iter().stream(),
         resolved_control, // The control events are added to the final resolved state
         &event_fetch,
-        parallel_fetches,
     )
     .await?;
 
@@ -265,7 +265,7 @@ where
 #[allow(clippy::arithmetic_side_effects)]
 fn get_auth_chain_diff<Id, Hasher>(
     auth_chain_sets: &[HashSet<Id, Hasher>],
-) -> impl Iterator<Item = Id> + Send + use<Id, Hasher>
+) -> impl Stream<Item = Id> + Send + use<Id, Hasher>
 where
     Id: Clone + Eq + Hash + Send,
     Hasher: BuildHasher + Send + Sync,
@@ -279,6 +279,7 @@ where
     id_counts
         .into_iter()
         .filter_map(move |(id, count)| (count < num_sets).then_some(id))
+        .stream()
 }
 
 /// Events are sorted from "earliest" to "latest".
@@ -310,13 +311,15 @@ where
     }
 
     // This is used in the `key_fn` passed to the lexico_topo_sort fn
-    let event_to_pl = stream::iter(graph.keys())
+    let event_to_pl = graph
+        .keys()
+        .stream()
         .map(|event_id| {
-            get_power_level_for_sender(event_id.clone(), fetch_event, parallel_fetches)
+            get_power_level_for_sender(event_id.clone(), fetch_event)
                 .map(move |res| res.map(|pl| (event_id, pl)))
         })
         .buffer_unordered(parallel_fetches)
-        .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| {
+        .ready_try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| {
             debug!(
                 event_id = event_id.borrow().as_str(),
                 power_level = i64::from(pl),
@@ -324,7 +327,7 @@ where
             );
 
             event_to_pl.insert(event_id.clone(), pl);
-            future::ok(event_to_pl)
+            Ok(event_to_pl)
         })
         .boxed()
         .await?;
@@ -475,7 +478,6 @@ where
 async fn get_power_level_for_sender<E, F, Fut>(
     event_id: E::Id,
     fetch_event: &F,
-    parallel_fetches: usize,
 ) -> serde_json::Result<Int>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -485,19 +487,17 @@ where
 {
     debug!("fetch event ({event_id}) senders power level");
 
-    let event = fetch_event(event_id.clone()).await;
+    let event = fetch_event(event_id).await;
 
-    let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten();
+    let auth_events = event.as_ref().map(Event::auth_events);
 
-    let pl = stream::iter(auth_events)
-        .map(|aid| fetch_event(aid.clone()))
-        .buffer_unordered(parallel_fetches.min(5))
-        .filter_map(future::ready)
-        .collect::<Vec<_>>()
-        .boxed()
-        .await
+    let pl = auth_events
         .into_iter()
-        .find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""));
+        .flatten()
+        .stream()
+        .broadn_filter_map(5, |aid| fetch_event(aid.clone()))
+        .ready_find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, ""))
+        .await;
 
     let content: PowerLevelsContentFields = match pl {
         | None => return Ok(int!(0)),
@@ -525,34 +525,28 @@ where
 /// For each `events_to_check` event we gather the events needed to auth it from
 /// the `fetch_event` closure and verify each event using the
 /// `event_auth::auth_check` function.
-async fn iterative_auth_check<'a, E, F, Fut, I>(
+async fn iterative_auth_check<'a, E, F, Fut, S>(
     room_version: &RoomVersion,
-    events_to_check: I,
+    events_to_check: S,
     unconflicted_state: StateMap<E::Id>,
     fetch_event: &F,
-    parallel_fetches: usize,
 ) -> Result<StateMap<E::Id>>
 where
     F: Fn(E::Id) -> Fut + Sync,
     Fut: Future<Output = Option<E>> + Send,
     E::Id: Borrow<EventId> + Clone + Eq + Ord + Send + Sync + 'a,
-    I: Iterator<Item = &'a E::Id> + Debug + Send + 'a,
+    S: Stream<Item = &'a E::Id> + Send + 'a,
     E: Event + Clone + Send + Sync,
 {
     debug!("starting iterative auth check");
 
-    trace!(
-        list = ?events_to_check,
-        "events to check"
-    );
-    let events_to_check: Vec<_> = stream::iter(events_to_check)
+    let events_to_check: Vec<_> = events_to_check
         .map(Result::Ok)
-        .map_ok(|event_id| {
-            fetch_event(event_id.clone()).map(move |result| {
-                result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
-            })
+        .broad_and_then(async |event_id| {
+            fetch_event(event_id.clone())
+                .await
+                .ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}")))
         })
-        .try_buffer_unordered(parallel_fetches)
         .try_collect()
         .boxed()
         .await?;
@@ -562,10 +556,10 @@ where
         .flat_map(|event: &E| event.auth_events().map(Clone::clone))
         .collect();
 
-    let auth_events: HashMap<E::Id, E> = stream::iter(auth_event_ids.into_iter())
-        .map(fetch_event)
-        .buffer_unordered(parallel_fetches)
-        .filter_map(future::ready)
+    let auth_events: HashMap<E::Id, E> = auth_event_ids
+        .into_iter()
+        .stream()
+        .broad_filter_map(fetch_event)
         .map(|auth_event| (auth_event.event_id().clone(), auth_event))
         .collect()
         .boxed()
@@ -574,7 +568,6 @@ where
     let auth_events = &auth_events;
     let mut resolved_state = unconflicted_state;
     for event in &events_to_check {
-        let event_id = event.event_id();
         let state_key = event
             .state_key()
             .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?;
@@ -603,24 +596,22 @@ where
             }
         }
 
-        stream::iter(
-            auth_types
-                .iter()
-                .filter_map(|key| Some((key, resolved_state.get(key)?))),
-        )
-        .filter_map(|(key, ev_id)| async move {
-            if let Some(event) = auth_events.get(ev_id.borrow()) {
-                Some((key, event.clone()))
-            } else {
-                Some((key, fetch_event(ev_id.clone()).await?))
-            }
-        })
-        .for_each(|(key, event)| {
-            //TODO: synapse checks "rejected_reason" is None here
-            auth_state.insert(key.to_owned(), event);
-            future::ready(())
-        })
-        .await;
+        auth_types
+            .iter()
+            .stream()
+            .ready_filter_map(|key| Some((key, resolved_state.get(key)?)))
+            .filter_map(|(key, ev_id)| async move {
+                if let Some(event) = auth_events.get(ev_id.borrow()) {
+                    Some((key, event.clone()))
+                } else {
+                    Some((key, fetch_event(ev_id.clone()).await?))
+                }
+            })
+            .ready_for_each(|(key, event)| {
+                //TODO: synapse checks "rejected_reason" is None here
+                auth_state.insert(key.to_owned(), event);
+            })
+            .await;
 
         debug!("event to check {:?}", event.event_id());
 
@@ -634,12 +625,25 @@ where
             future::ready(auth_state.get(&ty.with_state_key(key)))
         };
 
-        if auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await? {
-            // add event to resolved state map
-            resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
-        } else {
-            // synapse passes here on AuthError. We do not add this event to resolved_state.
-            warn!("event {event_id} failed the authentication check");
+        let auth_result =
+            auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await;
+
+        match auth_result {
+            | Ok(true) => {
+                // add event to resolved state map
+                resolved_state.insert(
+                    event.event_type().with_state_key(state_key),
+                    event.event_id().clone(),
+                );
+            },
+            | Ok(false) => {
+                // synapse passes here on AuthError. We do not add this event to resolved_state.
+                warn!("event {} failed the authentication check", event.event_id());
+            },
+            | Err(e) => {
+                debug_error!("event {} failed the authentication check: {e}", event.event_id());
+                return Err(e);
+            },
         }
     }
 
@@ -659,7 +663,6 @@ async fn mainline_sort<E, F, Fut>(
     to_sort: &[E::Id],
     resolved_power_level: Option<E::Id>,
     fetch_event: &F,
-    parallel_fetches: usize,
 ) -> Result<Vec<E::Id>>
 where
     F: Fn(E::Id) -> Fut + Sync,
@@ -682,11 +685,13 @@ where
         let event = fetch_event(p.clone())
             .await
             .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?;
+
         pl = None;
         for aid in event.auth_events() {
             let ev = fetch_event(aid.clone())
                 .await
                 .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
+
             if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") {
                 pl = Some(aid.to_owned());
                 break;
@@ -694,36 +699,32 @@ where
         }
     }
 
-    let mainline_map = mainline
+    let mainline_map: HashMap<_, _> = mainline
         .iter()
         .rev()
         .enumerate()
         .map(|(idx, eid)| ((*eid).clone(), idx))
-        .collect::<HashMap<_, _>>();
+        .collect();
 
-    let order_map = stream::iter(to_sort.iter())
-        .map(|ev_id| {
-            fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id)))
+    let order_map: HashMap<_, _> = to_sort
+        .iter()
+        .stream()
+        .broad_filter_map(async |ev_id| {
+            fetch_event(ev_id.clone()).await.map(|event| (event, ev_id))
         })
-        .buffer_unordered(parallel_fetches)
-        .filter_map(future::ready)
-        .map(|(event, ev_id)| {
+        .broad_filter_map(|(event, ev_id)| {
             get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event)
-                .map_ok(move |depth| (depth, event, ev_id))
+                .map_ok(move |depth| (ev_id, (depth, event.origin_server_ts(), ev_id)))
                 .map(Result::ok)
         })
-        .buffer_unordered(parallel_fetches)
-        .filter_map(future::ready)
-        .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| {
-            order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id));
-            future::ready(order_map)
-        })
+        .collect()
         .boxed()
         .await;
 
     // Sort the event_ids by their depth, timestamp and EventId
    // unwrap is OK: order_map and sort_event_ids are from to_sort (the same Vec)
-    let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::<Vec<_>>();
+    let mut sort_event_ids: Vec<_> = order_map.keys().map(|&k| k.clone()).collect();
+
     sort_event_ids.sort_by_key(|sort_id| &order_map[sort_id]);
 
     Ok(sort_event_ids)
@@ -744,6 +745,7 @@ where
 {
     while let Some(sort_ev) = event {
         debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
+
         let id = sort_ev.event_id();
         if let Some(depth) = mainline_map.get(id.borrow()) {
             return Ok(*depth);
@@ -754,6 +756,7 @@ where
             let aev = fetch_event(aid.clone())
                 .await
                 .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?;
+
             if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") {
                 event = Some(aev);
                 break;
@@ -884,7 +887,7 @@ mod tests {
             zara,
         },
     };
-    use crate::debug;
+    use crate::{debug, utils::stream::IterStream};
 
     async fn test_event_sort() {
         use futures::future::ready;
@@ -915,10 +918,9 @@ mod tests {
 
         let resolved_power = super::iterative_auth_check(
             &RoomVersion::V6,
-            sorted_power_events.iter(),
+            sorted_power_events.iter().stream(),
             HashMap::new(), // unconflicted events
             &fetcher,
-            1,
         )
         .await
         .expect("iterative auth check failed on resolved events");
@@ -932,7 +934,7 @@ mod tests {
             .get(&(StateEventType::RoomPowerLevels, "".into()))
             .cloned();
 
-        let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
+        let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher)
             .await
             .unwrap();
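
A note for readers without the tree checked out: the `utils::stream` extension traits this patch imports (`IterStream`, `BroadbandExt`, `WidebandExt`, `ReadyExt`, and the `Try*` variants) are not part of the diff. The sketch below illustrates the idiom only and is not conduwuit's implementation: `.stream()` lifts an iterator into a `futures` stream, the `broad*`/`wide*` combinators fold the old `map` + `buffer_unordered(n)` + `filter_map(future::ready)` pipeline into one call (with "broad" and "wide" presumably selecting different default concurrency widths), and the `ready_*` combinators take plain synchronous closures where `StreamExt` wants future-returning ones. The default width and exact trait bounds here are assumptions.

// Illustrative sketch only; assumes futures 0.3 and Rust >= 1.75 (RPITIT).
use futures::{future, Future, Stream, StreamExt, stream};

/// `.stream()`: lift any iterator into a stream, replacing `stream::iter(..)`.
pub trait IterStream: IntoIterator + Sized {
    fn stream(self) -> stream::Iter<Self::IntoIter> {
        stream::iter(self)
    }
}
impl<I: IntoIterator + Sized> IterStream for I {}

pub trait BroadbandExt: Stream + Sized {
    /// Hypothetical default concurrency width; the real value is a tuning choice.
    const WIDTH: usize = 32;

    /// Map each item to a future, keep up to `n` futures in flight at once,
    /// and drop `None` results: the exact pattern removed at every call site.
    fn broadn_filter_map<U, F, Fut>(self, n: usize, f: F) -> impl Stream<Item = U>
    where
        F: Fn(Self::Item) -> Fut,
        Fut: Future<Output = Option<U>>,
    {
        self.map(f).buffer_unordered(n).filter_map(future::ready)
    }

    /// Same, using the default width.
    fn broad_filter_map<U, F, Fut>(self, f: F) -> impl Stream<Item = U>
    where
        F: Fn(Self::Item) -> Fut,
        Fut: Future<Output = Option<U>>,
    {
        self.broadn_filter_map(Self::WIDTH, f)
    }
}
impl<S: Stream + Sized> BroadbandExt for S {}

pub trait ReadyExt: Stream + Sized {
    /// `filter_map` over a synchronous closure: wraps each result in
    /// `future::ready` so no per-item async state machine is constructed.
    fn ready_filter_map<U, F>(self, mut f: F) -> impl Stream<Item = U>
    where
        F: FnMut(Self::Item) -> Option<U>,
    {
        self.filter_map(move |item| future::ready(f(item)))
    }
}
impl<S: Stream + Sized> ReadyExt for S {}

With traits of this shape in scope, a call site from the diff reads as one pipeline, e.g. `all_conflicted.iter().stream().wide_filter_map(..)` or `get_auth_chain_diff(..).chain(..).broad_filter_map(..)`, which is what lets the patch thread `Stream`s (rather than collected `Vec`s plus a `parallel_fetches` argument) through `iterative_auth_check` and `mainline_sort`.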