From f2ca670c3b0858675312be60dcfb971384ce1244 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 8 Feb 2025 01:58:13 +0000 Subject: [PATCH] optimize further into state-res with SmallString triage and de-lints for state-res. Signed-off-by: Jason Volk --- Cargo.lock | 1 + Cargo.toml | 4 + src/api/client/membership.rs | 8 +- src/api/client/sync/v4.rs | 15 +- src/api/client/sync/v5.rs | 15 +- src/core/Cargo.toml | 3 + src/core/state_res/event_auth.rs | 142 ++++++++++-------- src/core/state_res/mod.rs | 89 ++++++----- src/core/state_res/room_version.rs | 1 + src/core/state_res/test_utils.rs | 43 +++--- .../rooms/event_handler/handle_outlier_pdu.rs | 2 +- .../rooms/event_handler/resolve_state.rs | 1 - .../rooms/event_handler/state_at_incoming.rs | 1 - .../event_handler/upgrade_outlier_pdu.rs | 10 +- src/service/rooms/timeline/mod.rs | 2 +- 15 files changed, 192 insertions(+), 145 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 5981a2a6..4441779e 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -810,6 +810,7 @@ dependencies = [ "libc", "libloading", "log", + "maplit", "nix", "num-traits", "rand", diff --git a/Cargo.toml b/Cargo.toml index d8f34544..a17aa4d6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -379,6 +379,7 @@ features = [ "unstable-msc4203", # sending to-device events to appservices "unstable-msc4210", # remove legacy mentions "unstable-extensible-events", + "unstable-pdu", ] [workspace.dependencies.rust-rocksdb] @@ -527,6 +528,9 @@ features = ["std"] version = "0.3.2" features = ["std"] +[workspace.dependencies.maplit] +version = "1.0.2" + # # Patches # diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 1045b014..6c970665 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -14,7 +14,7 @@ use conduwuit::{ result::FlatOk, state_res, trace, utils::{self, shuffle, IterStream, ReadyExt}, - warn, Err, PduEvent, Result, + warn, Err, PduEvent, Result, StateKey, }; use futures::{join, FutureExt, StreamExt, TryFutureExt}; use ruma::{ @@ -1151,8 +1151,8 @@ async fn join_room_by_id_helper_remote( debug!("Running send_join auth check"); let fetch_state = &state; - let state_fetch = |k: &'static StateEventType, s: String| async move { - let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; + let state_fetch = |k: StateEventType, s: StateKey| async move { + let shortstatekey = services.rooms.short.get_shortstatekey(&k, &s).await.ok()?; let event_id = fetch_state.get(&shortstatekey)?; services.rooms.timeline.get_pdu(event_id).await.ok() @@ -1162,7 +1162,7 @@ async fn join_room_by_id_helper_remote( &state_res::RoomVersion::new(&room_version_id)?, &parsed_join_pdu, None, // TODO: third party invite - |k, s| state_fetch(k, s.to_owned()), + |k, s| state_fetch(k.clone(), s.into()), ) .await .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index 4e474ef3..13f832b2 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -395,9 +395,12 @@ pub(crate) async fn sync_events_v4_route( .map_or(10, usize_from_u64_truncated) .min(100); - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); + todo_room.0.extend( + list.room_details + .required_state + .iter() + .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), + ); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date @@ -449,7 +452,11 @@ pub(crate) async fn sync_events_v4_route( .map_or(10, usize_from_u64_truncated) .min(100); - todo_room.0.extend(room.required_state.iter().cloned()); + todo_room.0.extend( + room.required_state + .iter() + .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), + ); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index 63731688..cda6c041 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -223,7 +223,11 @@ async fn fetch_subscriptions( let limit: UInt = room.timeline_limit; - todo_room.0.extend(room.required_state.iter().cloned()); + todo_room.0.extend( + room.required_state + .iter() + .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), + ); todo_room.1 = todo_room.1.max(usize_from_ruma(limit)); // 0 means unknown because it got out of date todo_room.2 = todo_room.2.min( @@ -303,9 +307,12 @@ async fn handle_lists<'a>( let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100); - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); + todo_room.0.extend( + list.room_details + .required_state + .iter() + .map(|(ty, sk)| (ty.clone(), sk.as_str().into())), + ); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date diff --git a/src/core/Cargo.toml b/src/core/Cargo.toml index d4b0c83b..b40dd3ad 100644 --- a/src/core/Cargo.toml +++ b/src/core/Cargo.toml @@ -116,5 +116,8 @@ nix.workspace = true hardened_malloc-rs.workspace = true hardened_malloc-rs.optional = true +[dev-dependencies] +maplit.workspace = true + [lints] workspace = true diff --git a/src/core/state_res/event_auth.rs b/src/core/state_res/event_auth.rs index 72a0216c..df2f8b36 100644 --- a/src/core/state_res/event_auth.rs +++ b/src/core/state_res/event_auth.rs @@ -21,7 +21,6 @@ use serde::{ Deserialize, }; use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue}; -use tracing::{debug, error, instrument, trace, warn}; use super::{ power_levels::{ @@ -29,8 +28,9 @@ use super::{ deserialize_power_levels_content_invite, deserialize_power_levels_content_redact, }, room_version::RoomVersion, - Error, Event, Result, StateEventType, TimelineEventType, + Error, Event, Result, StateEventType, StateKey, TimelineEventType, }; +use crate::{debug, error, trace, warn}; // FIXME: field extracting could be bundled for `content` #[derive(Deserialize)] @@ -56,15 +56,15 @@ pub fn auth_types_for_event( sender: &UserId, state_key: Option<&str>, content: &RawJsonValue, -) -> serde_json::Result> { +) -> serde_json::Result> { if kind == &TimelineEventType::RoomCreate { return Ok(vec![]); } let mut auth_types = vec![ - (StateEventType::RoomPowerLevels, String::new()), - (StateEventType::RoomMember, sender.to_string()), - (StateEventType::RoomCreate, String::new()), + (StateEventType::RoomPowerLevels, StateKey::new()), + (StateEventType::RoomMember, sender.as_str().into()), + (StateEventType::RoomCreate, StateKey::new()), ]; if kind == &TimelineEventType::RoomMember { @@ -82,7 +82,7 @@ pub fn auth_types_for_event( if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock] .contains(&membership) { - let key = (StateEventType::RoomJoinRules, String::new()); + let key = (StateEventType::RoomJoinRules, StateKey::new()); if !auth_types.contains(&key) { auth_types.push(key); } @@ -91,21 +91,22 @@ pub fn auth_types_for_event( .join_authorised_via_users_server .map(|m| m.deserialize()) { - let key = (StateEventType::RoomMember, u.to_string()); + let key = (StateEventType::RoomMember, u.as_str().into()); if !auth_types.contains(&key) { auth_types.push(key); } } } - let key = (StateEventType::RoomMember, state_key.to_owned()); + let key = (StateEventType::RoomMember, state_key.into()); if !auth_types.contains(&key) { auth_types.push(key); } if membership == MembershipState::Invite { if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) { - let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token); + let key = + (StateEventType::RoomThirdPartyInvite, t_id.signed.token.into()); if !auth_types.contains(&key) { auth_types.push(key); } @@ -128,7 +129,13 @@ pub fn auth_types_for_event( /// The `fetch_state` closure should gather state from a state snapshot. We need /// to know if the event passes auth against some state not a recursive /// collection of auth_events fields. -#[instrument(level = "debug", skip_all, fields(event_id = incoming_event.event_id().borrow().as_str()))] +#[tracing::instrument( + level = "debug", + skip_all, + fields( + event_id = incoming_event.event_id().borrow().as_str() + ) +)] pub async fn auth_check( room_version: &RoomVersion, incoming_event: &Incoming, @@ -136,10 +143,10 @@ pub async fn auth_check( fetch_state: F, ) -> Result where - F: Fn(&'static StateEventType, &str) -> Fut, + F: Fn(&StateEventType, &str) -> Fut + Send, Fut: Future> + Send, Fetched: Event + Send, - Incoming: Event + Send, + Incoming: Event + Send + Sync, { debug!( "auth_check beginning for {} ({})", @@ -262,6 +269,7 @@ where // sender domain of the event does not match the sender domain of the create // event, reject. #[derive(Deserialize)] + #[allow(clippy::items_after_statements)] struct RoomCreateContentFederate { #[serde(rename = "m.federate", default = "ruma::serde::default_true")] federate: bool, @@ -354,7 +362,7 @@ where join_rules_event.as_ref(), user_for_join_auth.as_deref(), &user_for_join_auth_membership, - room_create_event, + &room_create_event, )? { return Ok(false); } @@ -364,6 +372,7 @@ where } // If the sender's current membership state is not join, reject + #[allow(clippy::manual_let_else)] let sender_member_event = match sender_member_event { | Some(mem) => mem, | None => { @@ -498,19 +507,20 @@ where /// This is generated by calling `auth_types_for_event` with the membership /// event and the current State. #[allow(clippy::too_many_arguments)] +#[allow(clippy::cognitive_complexity)] fn valid_membership_change( room_version: &RoomVersion, target_user: &UserId, - target_user_membership_event: Option, + target_user_membership_event: Option<&impl Event>, sender: &UserId, - sender_membership_event: Option, + sender_membership_event: Option<&impl Event>, current_event: impl Event, - current_third_party_invite: Option, - power_levels_event: Option, - join_rules_event: Option, + current_third_party_invite: Option<&impl Event>, + power_levels_event: Option<&impl Event>, + join_rules_event: Option<&impl Event>, user_for_join_auth: Option<&UserId>, user_for_join_auth_membership: &MembershipState, - create_room: impl Event, + create_room: &impl Event, ) -> Result { #[derive(Deserialize)] struct GetThirdPartyInvite { @@ -856,6 +866,7 @@ fn check_power_levels( // and integers here debug!("validation of power event finished"); + #[allow(clippy::manual_let_else)] let current_state = match previous_power_event { | Some(current_state) => current_state, // If there is no previous m.room.power_levels event in the room, allow @@ -1054,6 +1065,7 @@ fn verify_third_party_invite( // If there is no m.room.third_party_invite event in the current room state with // state_key matching token, reject + #[allow(clippy::manual_let_else)] let current_tpid = match current_third_party_invite { | Some(id) => id, | None => return false, @@ -1069,12 +1081,14 @@ fn verify_third_party_invite( // If any signature in signed matches any public key in the // m.room.third_party_invite event, allow + #[allow(clippy::manual_let_else)] let tpid_ev = match from_json_str::(current_tpid.content().get()) { | Ok(ev) => ev, | Err(_) => return false, }; + #[allow(clippy::manual_let_else)] let decoded_invite_token = match Base64::parse(&tp_id.signed.token) { | Ok(tok) => tok, // FIXME: Log a warning? @@ -1096,7 +1110,7 @@ fn verify_third_party_invite( mod tests { use std::sync::Arc; - use ruma_events::{ + use ruma::events::{ room::{ join_rules::{ AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership, @@ -1107,7 +1121,7 @@ mod tests { }; use serde_json::value::to_raw_value as to_raw_json_value; - use crate::{ + use crate::state_res::{ event_auth::valid_membership_change, test_utils::{ alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id, @@ -1145,16 +1159,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V6, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), None, &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } @@ -1188,16 +1202,16 @@ mod tests { assert!(!valid_membership_change( &RoomVersion::V6, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), None, &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } @@ -1231,16 +1245,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V6, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), None, &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } @@ -1274,16 +1288,16 @@ mod tests { assert!(!valid_membership_change( &RoomVersion::V6, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), None, &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } @@ -1334,32 +1348,32 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V9, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), Some(alice()), &MembershipState::Join, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); assert!(!valid_membership_change( &RoomVersion::V9, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), Some(ella()), &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } @@ -1402,16 +1416,16 @@ mod tests { assert!(valid_membership_change( &RoomVersion::V7, target_user, - fetch_state(StateEventType::RoomMember, target_user.to_string()), + fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(), sender, - fetch_state(StateEventType::RoomMember, sender.to_string()), + fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(), &requester, - None::, - fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), - fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None::<&PduEvent>, + fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(), + fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(), None, &MembershipState::Leave, - fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(), ) .unwrap()); } diff --git a/src/core/state_res/mod.rs b/src/core/state_res/mod.rs index e4054377..19ea3cc0 100644 --- a/src/core/state_res/mod.rs +++ b/src/core/state_res/mod.rs @@ -1,3 +1,5 @@ +#![cfg_attr(test, allow(warnings))] + pub(crate) mod error; pub mod event_auth; mod power_levels; @@ -12,7 +14,7 @@ use std::{ cmp::{Ordering, Reverse}, collections::{BinaryHeap, HashMap, HashSet}, fmt::Debug, - hash::Hash, + hash::{BuildHasher, Hash}, }; use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; @@ -32,13 +34,13 @@ pub use self::{ room_version::RoomVersion, state_event::Event, }; -use crate::{debug, trace, warn}; +use crate::{debug, pdu::StateKey, trace, warn}; /// A mapping of event type and state_key to some value `T`, usually an /// `EventId`. pub type StateMap = HashMap; pub type StateMapItem = (TypeStateKey, T); -pub type TypeStateKey = (StateEventType, String); +pub type TypeStateKey = (StateEventType, StateKey); type Result = crate::Result; @@ -68,10 +70,10 @@ type Result = crate::Result; /// event is part of the same room. //#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets, //#[tracing::instrument(level event_fetch))] -pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( +pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>( room_version: &RoomVersionId, - state_sets: impl IntoIterator + Send, - auth_chain_sets: &'a [HashSet], + state_sets: Sets, + auth_chain_sets: &'a [HashSet], event_fetch: &Fetch, event_exists: &Exists, parallel_fetches: usize, @@ -81,7 +83,9 @@ where FetchFut: Future> + Send, Exists: Fn(E::Id) -> ExistsFut + Sync, ExistsFut: Future + Send, + Sets: IntoIterator + Send, SetIter: Iterator> + Clone + Send, + Hasher: BuildHasher + Send + Sync, E: Event + Clone + Send + Sync, E::Id: Borrow + Send + Sync, for<'b> &'b E: Send, @@ -178,7 +182,7 @@ where trace!(list = ?events_to_resolve, "events left to resolve"); // This "epochs" power level event - let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new())); + let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new())); debug!(event_id = ?power_event, "power event"); @@ -222,16 +226,17 @@ fn separate<'a, Id>( where Id: Clone + Eq + Hash + 'a, { - let mut state_set_count = 0_usize; + let mut state_set_count: usize = 0; let mut occurrences = HashMap::<_, HashMap<_, _>>::new(); - let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1); + let state_sets_iter = + state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1)); for (k, v) in state_sets_iter.flatten() { occurrences .entry(k) .or_default() .entry(v) - .and_modify(|x| *x += 1) + .and_modify(|x: &mut usize| *x = x.saturating_add(1)) .or_insert(1); } @@ -246,7 +251,7 @@ where conflicted_state .entry((k.0.clone(), k.1.clone())) .and_modify(|x: &mut Vec<_>| x.push(id.clone())) - .or_insert(vec![id.clone()]); + .or_insert_with(|| vec![id.clone()]); } } } @@ -255,9 +260,13 @@ where } /// Returns a Vec of deduped EventIds that appear in some chains but not others. -fn get_auth_chain_diff(auth_chain_sets: &[HashSet]) -> impl Iterator + Send +#[allow(clippy::arithmetic_side_effects)] +fn get_auth_chain_diff( + auth_chain_sets: &[HashSet], +) -> impl Iterator + Send where Id: Clone + Eq + Hash + Send, + Hasher: BuildHasher + Send + Sync, { let num_sets = auth_chain_sets.len(); let mut id_counts: HashMap = HashMap::new(); @@ -288,7 +297,7 @@ async fn reverse_topological_power_sort( where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send, + E: Event + Send + Sync, E::Id: Borrow + Send + Sync, { debug!("reverse topological sort of power events"); @@ -337,14 +346,15 @@ where /// `key_fn` is used as to obtain the power level and age of an event for /// breaking ties (together with the event ID). #[tracing::instrument(level = "debug", skip_all)] -pub async fn lexicographical_topological_sort( - graph: &HashMap>, +pub async fn lexicographical_topological_sort( + graph: &HashMap>, key_fn: &F, ) -> Result> where F: Fn(Id) -> Fut + Sync, Fut: Future> + Send, - Id: Borrow + Clone + Eq + Hash + Ord + Send, + Id: Borrow + Clone + Eq + Hash + Ord + Send + Sync, + Hasher: BuildHasher + Default + Clone + Send + Sync, { #[derive(PartialEq, Eq)] struct TieBreaker<'a, Id> { @@ -395,7 +405,7 @@ where // The number of events that depend on the given event (the EventId key) // How many events reference this event in the DAG as a parent - let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new(); + let mut reverse_graph: HashMap<_, HashSet<_, Hasher>> = HashMap::new(); // Vec of nodes that have zero out degree, least recent events. let mut zero_outdegree = Vec::new(); @@ -727,8 +737,8 @@ async fn get_mainline_depth( where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send, - E::Id: Borrow + Send, + E: Event + Send + Sync, + E::Id: Borrow + Send + Sync, { while let Some(sort_ev) = event { debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline"); @@ -758,10 +768,10 @@ async fn add_event_and_auth_chain_to_graph( auth_diff: &HashSet, fetch_event: &F, ) where - F: Fn(E::Id) -> Fut, + F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, - E: Event + Send, - E::Id: Borrow + Clone + Send, + E: Event + Send + Sync, + E::Id: Borrow + Clone + Send + Sync, { let mut state = vec![event_id]; while let Some(eid) = state.pop() { @@ -788,7 +798,7 @@ where F: Fn(E::Id) -> Fut + Sync, Fut: Future> + Send, E: Event + Send, - E::Id: Borrow + Send, + E::Id: Borrow + Send + Sync, { match fetch(event_id.clone()).await.as_ref() { | Some(state) => is_power_event(state), @@ -820,18 +830,18 @@ fn is_power_event(event: impl Event) -> bool { /// Convenience trait for adding event type plus state key to state maps. pub trait EventTypeExt { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); + fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey); } impl EventTypeExt for StateEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { (self, state_key.into()) } } impl EventTypeExt for TimelineEventType { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { - (self.to_string().into(), state_key.into()) + fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { + (self.into(), state_key.into()) } } @@ -839,7 +849,7 @@ impl EventTypeExt for &T where T: EventTypeExt + Clone, { - fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, StateKey) { self.to_owned().with_state_key(state_key) } } @@ -858,13 +868,11 @@ mod tests { room::join_rules::{JoinRule, RoomJoinRulesEventContent}, StateEventType, TimelineEventType, }, - int, uint, + int, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId, }; - use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId}; use serde_json::{json, value::to_raw_value as to_raw_json_value}; - use tracing::debug; - use crate::{ + use super::{ is_power_event, room_version::RoomVersion, test_utils::{ @@ -874,6 +882,7 @@ mod tests { }, Event, EventTypeExt, StateMap, }; + use crate::debug; async fn test_event_sort() { use futures::future::ready; @@ -898,11 +907,11 @@ mod tests { let fetcher = |id| ready(events.get(&id).cloned()); let sorted_power_events = - crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) + super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) .await .unwrap(); - let resolved_power = crate::iterative_auth_check( + let resolved_power = super::iterative_auth_check( &RoomVersion::V6, sorted_power_events.iter(), HashMap::new(), // unconflicted events @@ -918,10 +927,10 @@ mod tests { events_to_sort.shuffle(&mut rand::thread_rng()); let power_level = resolved_power - .get(&(StateEventType::RoomPowerLevels, "".to_owned())) + .get(&(StateEventType::RoomPowerLevels, "".into())) .cloned(); - let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1) + let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1) .await .unwrap(); @@ -1302,7 +1311,7 @@ mod tests { }) .collect(); - let resolved = match crate::resolve( + let resolved = match super::resolve( &RoomVersionId::V2, &state_sets, &auth_chain, @@ -1333,7 +1342,7 @@ mod tests { event_id("p") => hashset![event_id("o")], }; - let res = crate::lexicographical_topological_sort(&graph, &|_id| async { + let res = super::lexicographical_topological_sort(&graph, &|_id| async { Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) }) .await @@ -1421,7 +1430,7 @@ mod tests { let fetcher = |id: ::Id| ready(ev_map.get(&id).cloned()); let exists = |id: ::Id| ready(ev_map.get(&id).is_some()); - let resolved = match crate::resolve( + let resolved = match super::resolve( &RoomVersionId::V6, &state_sets, &auth_chain, @@ -1552,7 +1561,7 @@ mod tests { #[allow(unused_mut)] let mut x = StateMap::new(); $( - x.insert(($kind, $key.to_owned()), $id); + x.insert(($kind, $key.into()), $id); )* x }}; diff --git a/src/core/state_res/room_version.rs b/src/core/state_res/room_version.rs index e1b0afe1..8dfd6cde 100644 --- a/src/core/state_res/room_version.rs +++ b/src/core/state_res/room_version.rs @@ -32,6 +32,7 @@ pub enum StateResolutionVersion { } #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +#[allow(clippy::struct_excessive_bools)] pub struct RoomVersion { /// The stability of this room. pub disposition: RoomDisposition, diff --git a/src/core/state_res/test_utils.rs b/src/core/state_res/test_utils.rs index 7954b28d..9c2b151f 100644 --- a/src/core/state_res/test_utils.rs +++ b/src/core/state_res/test_utils.rs @@ -7,28 +7,28 @@ use std::{ }, }; -use futures_util::future::ready; -use js_int::{int, uint}; -use ruma_common::{ - event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, - RoomVersionId, ServerSignatures, UserId, -}; -use ruma_events::{ - pdu::{EventHash, Pdu, RoomV3Pdu}, - room::{ - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, +use futures::future::ready; +use ruma::{ + event_id, + events::{ + pdu::{EventHash, Pdu, RoomV3Pdu}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + TimelineEventType, }, - TimelineEventType, + int, room_id, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, + RoomVersionId, ServerSignatures, UserId, }; use serde_json::{ json, value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue}, }; -use tracing::info; pub(crate) use self::event::PduEvent; -use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap}; +use super::auth_types_for_event; +use crate::{info, Event, EventTypeExt, Result, StateMap}; static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); @@ -88,7 +88,7 @@ pub(crate) async fn do_check( // Resolve the current state and add it to the state_at_event map then continue // on in "time" - for node in crate::lexicographical_topological_sort(&graph, &|_id| async { + for node in super::lexicographical_topological_sort(&graph, &|_id| async { Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) }) .await @@ -135,7 +135,7 @@ pub(crate) async fn do_check( let event_map = &event_map; let fetch = |id: ::Id| ready(event_map.get(&id).cloned()); let exists = |id: ::Id| ready(event_map.get(&id).is_some()); - let resolved = crate::resolve( + let resolved = super::resolve( &RoomVersionId::V6, state_sets, &auth_chain_sets, @@ -223,7 +223,7 @@ pub(crate) async fn do_check( // Filter out the dummy messages events. // These act as points in time where there should be a known state to // test against. - && **k != ("m.room.message".into(), "dummy".to_owned()) + && **k != ("m.room.message".into(), "dummy".into()) }) .map(|(k, v)| (k.clone(), v.clone())) .collect::>(); @@ -239,7 +239,8 @@ impl TestStore { self.0 .get(event_id) .cloned() - .ok_or_else(|| Error::NotFound(format!("{event_id} not found"))) + .ok_or_else(|| super::Error::NotFound(format!("{event_id} not found"))) + .map_err(Into::into) } /// Returns a Vec of the related auth events to the given `event`. @@ -582,8 +583,10 @@ pub(crate) fn INITIAL_EDGES() -> Vec { } pub(crate) mod event { - use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId}; - use ruma_events::{pdu::Pdu, TimelineEventType}; + use ruma::{ + events::{pdu::Pdu, TimelineEventType}, + MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId, + }; use serde::{Deserialize, Serialize}; use serde_json::value::RawValue as RawJsonValue; diff --git a/src/service/rooms/event_handler/handle_outlier_pdu.rs b/src/service/rooms/event_handler/handle_outlier_pdu.rs index 3cc15fc4..e628c77a 100644 --- a/src/service/rooms/event_handler/handle_outlier_pdu.rs +++ b/src/service/rooms/event_handler/handle_outlier_pdu.rs @@ -133,7 +133,7 @@ pub(super) async fn handle_outlier_pdu<'a>( )); } - let state_fetch = |ty: &'static StateEventType, sk: &str| { + let state_fetch = |ty: &StateEventType, sk: &str| { let key = (ty.to_owned(), sk.into()); ready(auth_events.get(&key)) }; diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index 28011a1b..37d47d47 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -63,7 +63,6 @@ pub async fn resolve_state( .multi_get_statekey_from_short(shortstatekeys) .zip(event_ids) .ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id))) - .map(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id)) .collect() }) .map(Ok::<_, Error>) diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 843b2af9..2eb6013a 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -172,7 +172,6 @@ async fn state_at_incoming_fork( .short .get_statekey_from_short(*k) .map_ok(|(ty, sk)| ((ty, sk), id.clone())) - .map_ok(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id)) }) .ready_filter_map(Result::ok) .collect() diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index f319ba48..385d2142 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -3,7 +3,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::In use conduwuit::{ debug, debug_info, err, implement, state_res, trace, utils::stream::{BroadbandExt, ReadyExt}, - warn, Err, EventTypeExt, PduEvent, Result, + warn, Err, EventTypeExt, PduEvent, Result, StateKey, }; use futures::{future::ready, FutureExt, StreamExt}; use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName}; @@ -71,8 +71,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( debug!("Performing auth check"); // 11. Check the auth of the event passes based on the state of the event let state_fetch_state = &state_at_incoming_event; - let state_fetch = |k: &'static StateEventType, s: String| async move { - let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; + let state_fetch = |k: StateEventType, s: StateKey| async move { + let shortstatekey = self.services.short.get_shortstatekey(&k, &s).await.ok()?; let event_id = state_fetch_state.get(&shortstatekey)?; self.services.timeline.get_pdu(event_id).await.ok() @@ -82,7 +82,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( &room_version, &incoming_pdu, None, // TODO: third party invite - |k, s| state_fetch(k, s.to_owned()), + |ty, sk| state_fetch(ty.clone(), sk.into()), ) .await .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; @@ -104,7 +104,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu( ) .await?; - let state_fetch = |k: &'static StateEventType, s: &str| { + let state_fetch = |k: &StateEventType, s: &str| { let key = k.with_state_key(s); ready(auth_events.get(&key).cloned()) }; diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index d6154121..9d6ee982 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -747,7 +747,7 @@ impl Service { }; let auth_fetch = |k: &StateEventType, s: &str| { - let key = (k.clone(), s.to_owned()); + let key = (k.clone(), s.into()); ready(auth_events.get(&key)) };