From 0a9a9b3c92852cae269aaf2cb3894658b5e35a54 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 5 Feb 2025 12:22:22 +0000 Subject: [PATCH] larcen state-res from ruma --- Cargo.toml | 1 - src/api/client/membership.rs | 6 +- src/api/client/sync/v5.rs | 3 +- src/core/error/mod.rs | 2 +- src/core/mod.rs | 2 + src/core/pdu/event.rs | 2 +- src/core/state_res/LICENSE | 17 + src/core/state_res/error.rs | 23 + src/core/state_res/event_auth.rs | 1418 ++++++++++++++ src/core/state_res/mod.rs | 1644 +++++++++++++++++ src/core/state_res/outcomes.txt | 104 ++ src/core/state_res/power_levels.rs | 256 +++ src/core/state_res/room_version.rs | 149 ++ src/core/state_res/state_event.rs | 102 + src/core/state_res/state_res_bench.rs | 648 +++++++ src/core/state_res/test_utils.rs | 688 +++++++ src/service/rooms/event_handler/fetch_prev.rs | 11 +- .../rooms/event_handler/handle_outlier_pdu.rs | 6 +- src/service/rooms/event_handler/mod.rs | 6 +- .../rooms/event_handler/resolve_state.rs | 9 +- .../rooms/event_handler/state_at_incoming.rs | 4 +- .../event_handler/upgrade_outlier_pdu.rs | 10 +- src/service/rooms/state/mod.rs | 2 +- src/service/rooms/timeline/mod.rs | 2 +- 24 files changed, 5082 insertions(+), 33 deletions(-) create mode 100644 src/core/state_res/LICENSE create mode 100644 src/core/state_res/error.rs create mode 100644 src/core/state_res/event_auth.rs create mode 100644 src/core/state_res/mod.rs create mode 100644 src/core/state_res/outcomes.txt create mode 100644 src/core/state_res/power_levels.rs create mode 100644 src/core/state_res/room_version.rs create mode 100644 src/core/state_res/state_event.rs create mode 100644 src/core/state_res/state_res_bench.rs create mode 100644 src/core/state_res/test_utils.rs diff --git a/Cargo.toml b/Cargo.toml index b93877bd..d8f34544 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -355,7 +355,6 @@ features = [ "federation-api", "markdown", "push-gateway-api-c", - "state-res", "server-util", "unstable-exhaustive-types", "ring-compat", diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index 449d44d5..1045b014 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -12,7 +12,7 @@ use conduwuit::{ at, debug, debug_info, debug_warn, err, info, pdu::{gen_event_id_canonical_json, PduBuilder}, result::FlatOk, - trace, + state_res, trace, utils::{self, shuffle, IterStream, ReadyExt}, warn, Err, PduEvent, Result, }; @@ -40,8 +40,8 @@ use ruma::{ }, StateEventType, }, - state_res, CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, - OwnedServerName, OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, + CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId, OwnedServerName, + OwnedUserId, RoomId, RoomVersionId, ServerName, UserId, }; use service::{ appservice::RegistrationInfo, diff --git a/src/api/client/sync/v5.rs b/src/api/client/sync/v5.rs index f8ee1047..63731688 100644 --- a/src/api/client/sync/v5.rs +++ b/src/api/client/sync/v5.rs @@ -11,7 +11,7 @@ use conduwuit::{ math::{ruma_from_usize, usize_from_ruma}, BoolExt, IterStream, ReadyExt, TryFutureExtExt, }, - warn, Error, Result, + warn, Error, Result, TypeStateKey, }; use futures::{FutureExt, StreamExt, TryFutureExt}; use ruma::{ @@ -24,7 +24,6 @@ use ruma::{ AnyRawAccountDataEvent, AnySyncEphemeralRoomEvent, StateEventType, TimelineEventType, }, serde::Raw, - state_res::TypeStateKey, uint, DeviceId, OwnedEventId, OwnedRoomId, RoomId, UInt, UserId, }; use service::{rooms::read_receipt::pack_receipts, PduCount}; diff --git 
a/src/core/error/mod.rs b/src/core/error/mod.rs index 88ac6d09..16613b7e 100644 --- a/src/core/error/mod.rs +++ b/src/core/error/mod.rs @@ -121,7 +121,7 @@ pub enum Error { #[error(transparent)] Signatures(#[from] ruma::signatures::Error), #[error(transparent)] - StateRes(#[from] ruma::state_res::Error), + StateRes(#[from] crate::state_res::Error), #[error("uiaa")] Uiaa(ruma::api::client::uiaa::UiaaInfo), diff --git a/src/core/mod.rs b/src/core/mod.rs index ee128628..cd56774a 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -8,6 +8,7 @@ pub mod metrics; pub mod mods; pub mod pdu; pub mod server; +pub mod state_res; pub mod utils; pub use ::arrayvec; @@ -22,6 +23,7 @@ pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; pub use pdu::{Event, PduBuilder, PduCount, PduEvent, PduId, RawPduId, StateKey}; pub use server::Server; +pub use state_res::{EventTypeExt, RoomVersion, StateMap, TypeStateKey}; pub use utils::{ctor, dtor, implement, result, result::Result}; pub use crate as conduwuit_core; diff --git a/src/core/pdu/event.rs b/src/core/pdu/event.rs index 6a92afe8..d5c0561e 100644 --- a/src/core/pdu/event.rs +++ b/src/core/pdu/event.rs @@ -1,8 +1,8 @@ -pub use ruma::state_res::Event; use ruma::{events::TimelineEventType, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId}; use serde_json::value::RawValue as RawJsonValue; use super::Pdu; +pub use crate::state_res::Event; impl Event for Pdu { type Id = OwnedEventId; diff --git a/src/core/state_res/LICENSE b/src/core/state_res/LICENSE new file mode 100644 index 00000000..c103a044 --- /dev/null +++ b/src/core/state_res/LICENSE @@ -0,0 +1,17 @@ +//! Permission is hereby granted, free of charge, to any person obtaining a copy +//! of this software and associated documentation files (the "Software"), to +//! deal in the Software without restriction, including without limitation the +//! rights to use, copy, modify, merge, publish, distribute, sublicense, and/or +//! sell copies of the Software, and to permit persons to whom the Software is +//! furnished to do so, subject to the following conditions: + +//! The above copyright notice and this permission notice shall be included in +//! all copies or substantial portions of the Software. + +//! THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +//! IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +//! FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +//! AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +//! LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING +//! FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS +//! IN THE SOFTWARE. diff --git a/src/core/state_res/error.rs b/src/core/state_res/error.rs new file mode 100644 index 00000000..7711d878 --- /dev/null +++ b/src/core/state_res/error.rs @@ -0,0 +1,23 @@ +use serde_json::Error as JsonError; +use thiserror::Error; + +/// Represents the various errors that arise when resolving state. +#[derive(Error, Debug)] +#[non_exhaustive] +pub enum Error { + /// A deserialization error. + #[error(transparent)] + SerdeJson(#[from] JsonError), + + /// The given option or version is unsupported. + #[error("Unsupported room version: {0}")] + Unsupported(String), + + /// The given event was not found. + #[error("Not found error: {0}")] + NotFound(String), + + /// Invalid fields in the given PDU. 
+ #[error("Invalid PDU: {0}")] + InvalidPdu(String), +} diff --git a/src/core/state_res/event_auth.rs b/src/core/state_res/event_auth.rs new file mode 100644 index 00000000..72a0216c --- /dev/null +++ b/src/core/state_res/event_auth.rs @@ -0,0 +1,1418 @@ +use std::{borrow::Borrow, collections::BTreeSet}; + +use futures::{ + future::{join3, OptionFuture}, + Future, +}; +use ruma::{ + events::room::{ + create::RoomCreateEventContent, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, ThirdPartyInvite}, + power_levels::RoomPowerLevelsEventContent, + third_party_invite::RoomThirdPartyInviteEventContent, + }, + int, + serde::{Base64, Raw}, + Int, OwnedUserId, RoomVersionId, UserId, +}; +use serde::{ + de::{Error as _, IgnoredAny}, + Deserialize, +}; +use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue}; +use tracing::{debug, error, instrument, trace, warn}; + +use super::{ + power_levels::{ + deserialize_power_levels, deserialize_power_levels_content_fields, + deserialize_power_levels_content_invite, deserialize_power_levels_content_redact, + }, + room_version::RoomVersion, + Error, Event, Result, StateEventType, TimelineEventType, +}; + +// FIXME: field extracting could be bundled for `content` +#[derive(Deserialize)] +struct GetMembership { + membership: MembershipState, +} + +#[derive(Deserialize)] +struct RoomMemberContentFields { + membership: Option>, + join_authorised_via_users_server: Option>, +} + +/// For the given event `kind` what are the relevant auth events that are needed +/// to authenticate this `content`. +/// +/// # Errors +/// +/// This function will return an error if the supplied `content` is not a JSON +/// object. +pub fn auth_types_for_event( + kind: &TimelineEventType, + sender: &UserId, + state_key: Option<&str>, + content: &RawJsonValue, +) -> serde_json::Result> { + if kind == &TimelineEventType::RoomCreate { + return Ok(vec![]); + } + + let mut auth_types = vec![ + (StateEventType::RoomPowerLevels, String::new()), + (StateEventType::RoomMember, sender.to_string()), + (StateEventType::RoomCreate, String::new()), + ]; + + if kind == &TimelineEventType::RoomMember { + #[derive(Deserialize)] + struct RoomMemberContentFields { + membership: Option>, + third_party_invite: Option>, + join_authorised_via_users_server: Option>, + } + + if let Some(state_key) = state_key { + let content: RoomMemberContentFields = from_json_str(content.get())?; + + if let Some(Ok(membership)) = content.membership.map(|m| m.deserialize()) { + if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock] + .contains(&membership) + { + let key = (StateEventType::RoomJoinRules, String::new()); + if !auth_types.contains(&key) { + auth_types.push(key); + } + + if let Some(Ok(u)) = content + .join_authorised_via_users_server + .map(|m| m.deserialize()) + { + let key = (StateEventType::RoomMember, u.to_string()); + if !auth_types.contains(&key) { + auth_types.push(key); + } + } + } + + let key = (StateEventType::RoomMember, state_key.to_owned()); + if !auth_types.contains(&key) { + auth_types.push(key); + } + + if membership == MembershipState::Invite { + if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) { + let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token); + if !auth_types.contains(&key) { + auth_types.push(key); + } + } + } + } + } + } + + Ok(auth_types) +} + +/// Authenticate the incoming `event`. 
+/// +/// The steps of authentication are: +/// +/// * check that the event is being authenticated for the correct room +/// * then there are checks for specific event types +/// +/// The `fetch_state` closure should gather state from a state snapshot. We need +/// to know if the event passes auth against some state not a recursive +/// collection of auth_events fields. +#[instrument(level = "debug", skip_all, fields(event_id = incoming_event.event_id().borrow().as_str()))] +pub async fn auth_check( + room_version: &RoomVersion, + incoming_event: &Incoming, + current_third_party_invite: Option<&Incoming>, + fetch_state: F, +) -> Result +where + F: Fn(&'static StateEventType, &str) -> Fut, + Fut: Future> + Send, + Fetched: Event + Send, + Incoming: Event + Send, +{ + debug!( + "auth_check beginning for {} ({})", + incoming_event.event_id(), + incoming_event.event_type() + ); + + // [synapse] check that all the events are in the same room as `incoming_event` + + // [synapse] do_sig_check check the event has valid signatures for member events + + // TODO do_size_check is false when called by `iterative_auth_check` + // do_size_check is also mostly accomplished by ruma with the exception of + // checking event_type, state_key, and json are below a certain size (255 and + // 65_536 respectively) + + let sender = incoming_event.sender(); + + // Implementation of https://spec.matrix.org/latest/rooms/v1/#authorization-rules + // + // 1. If type is m.room.create: + if *incoming_event.event_type() == TimelineEventType::RoomCreate { + #[derive(Deserialize)] + struct RoomCreateContentFields { + room_version: Option>, + creator: Option>, + } + + debug!("start m.room.create check"); + + // If it has any previous events, reject + if incoming_event.prev_events().next().is_some() { + warn!("the room creation event had previous events"); + return Ok(false); + } + + // If the domain of the room_id does not match the domain of the sender, reject + let Some(room_id_server_name) = incoming_event.room_id().server_name() else { + warn!("room ID has no servername"); + return Ok(false); + }; + + if room_id_server_name != sender.server_name() { + warn!("servername of room ID does not match servername of sender"); + return Ok(false); + } + + // If content.room_version is present and is not a recognized version, reject + let content: RoomCreateContentFields = from_json_str(incoming_event.content().get())?; + if content + .room_version + .is_some_and(|v| v.deserialize().is_err()) + { + warn!("invalid room version found in m.room.create event"); + return Ok(false); + } + + if !room_version.use_room_create_sender { + // If content has no creator field, reject + if content.creator.is_none() { + warn!("no creator field found in m.room.create content"); + return Ok(false); + } + } + + debug!("m.room.create event was allowed"); + return Ok(true); + } + + /* + // TODO: In the past this code caused problems federating with synapse, maybe this has been + // resolved already. Needs testing. + // + // 2. Reject if auth_events + // a. auth_events cannot have duplicate keys since it's a BTree + // b. 
All entries are valid auth events according to spec + let expected_auth = auth_types_for_event( + incoming_event.kind, + sender, + incoming_event.state_key, + incoming_event.content().clone(), + ); + + dbg!(&expected_auth); + + for ev_key in auth_events.keys() { + // (b) + if !expected_auth.contains(ev_key) { + warn!("auth_events contained invalid auth event"); + return Ok(false); + } + } + */ + + let (room_create_event, power_levels_event, sender_member_event) = join3( + fetch_state(&StateEventType::RoomCreate, ""), + fetch_state(&StateEventType::RoomPowerLevels, ""), + fetch_state(&StateEventType::RoomMember, sender.as_str()), + ) + .await; + + let room_create_event = match room_create_event { + | None => { + warn!("no m.room.create event in auth chain"); + return Ok(false); + }, + | Some(e) => e, + }; + + // 3. If event does not have m.room.create in auth_events reject + if !incoming_event + .auth_events() + .any(|id| id.borrow() == room_create_event.event_id().borrow()) + { + warn!("no m.room.create event in auth events"); + return Ok(false); + } + + // If the create event content has the field m.federate set to false and the + // sender domain of the event does not match the sender domain of the create + // event, reject. + #[derive(Deserialize)] + struct RoomCreateContentFederate { + #[serde(rename = "m.federate", default = "ruma::serde::default_true")] + federate: bool, + } + let room_create_content: RoomCreateContentFederate = + from_json_str(room_create_event.content().get())?; + if !room_create_content.federate + && room_create_event.sender().server_name() != incoming_event.sender().server_name() + { + warn!( + "room is not federated and event's sender domain does not match create event's \ + sender domain" + ); + return Ok(false); + } + + // Only in some room versions 6 and below + if room_version.special_case_aliases_auth { + // 4. 
If type is m.room.aliases + if *incoming_event.event_type() == TimelineEventType::RoomAliases { + debug!("starting m.room.aliases check"); + + // If sender's domain doesn't matches state_key, reject + if incoming_event.state_key() != Some(sender.server_name().as_str()) { + warn!("state_key does not match sender"); + return Ok(false); + } + + debug!("m.room.aliases event was allowed"); + return Ok(true); + } + } + + // If type is m.room.member + if *incoming_event.event_type() == TimelineEventType::RoomMember { + debug!("starting m.room.member check"); + let state_key = match incoming_event.state_key() { + | None => { + warn!("no statekey in member event"); + return Ok(false); + }, + | Some(s) => s, + }; + + let content: RoomMemberContentFields = from_json_str(incoming_event.content().get())?; + if content + .membership + .as_ref() + .and_then(|m| m.deserialize().ok()) + .is_none() + { + warn!("no valid membership field found for m.room.member event content"); + return Ok(false); + } + + let target_user = + <&UserId>::try_from(state_key).map_err(|e| Error::InvalidPdu(format!("{e}")))?; + + let user_for_join_auth = content + .join_authorised_via_users_server + .as_ref() + .and_then(|u| u.deserialize().ok()); + + let user_for_join_auth_event: OptionFuture<_> = user_for_join_auth + .as_ref() + .map(|auth_user| fetch_state(&StateEventType::RoomMember, auth_user.as_str())) + .into(); + + let target_user_member_event = + fetch_state(&StateEventType::RoomMember, target_user.as_str()); + + let join_rules_event = fetch_state(&StateEventType::RoomJoinRules, ""); + + let (join_rules_event, target_user_member_event, user_for_join_auth_event) = + join3(join_rules_event, target_user_member_event, user_for_join_auth_event).await; + + let user_for_join_auth_membership = user_for_join_auth_event + .and_then(|mem| from_json_str::(mem?.content().get()).ok()) + .map_or(MembershipState::Leave, |mem| mem.membership); + + if !valid_membership_change( + room_version, + target_user, + target_user_member_event.as_ref(), + sender, + sender_member_event.as_ref(), + incoming_event, + current_third_party_invite, + power_levels_event.as_ref(), + join_rules_event.as_ref(), + user_for_join_auth.as_deref(), + &user_for_join_auth_membership, + room_create_event, + )? 
{ + return Ok(false); + } + + debug!("m.room.member event was allowed"); + return Ok(true); + } + + // If the sender's current membership state is not join, reject + let sender_member_event = match sender_member_event { + | Some(mem) => mem, + | None => { + warn!("sender not found in room"); + return Ok(false); + }, + }; + + let sender_membership_event_content: RoomMemberContentFields = + from_json_str(sender_member_event.content().get())?; + let membership_state = sender_membership_event_content + .membership + .expect("we should test before that this field exists") + .deserialize()?; + + if !matches!(membership_state, MembershipState::Join) { + warn!("sender's membership is not join"); + return Ok(false); + } + + // If type is m.room.third_party_invite + let sender_power_level = if let Some(pl) = &power_levels_event { + let content = deserialize_power_levels_content_fields(pl.content().get(), room_version)?; + if let Some(level) = content.get_user_power(sender) { + *level + } else { + content.users_default + } + } else { + // If no power level event found the creator gets 100 everyone else gets 0 + let is_creator = if room_version.use_room_create_sender { + room_create_event.sender() == sender + } else { + #[allow(deprecated)] + from_json_str::(room_create_event.content().get()) + .is_ok_and(|create| create.creator.unwrap() == *sender) + }; + + if is_creator { + int!(100) + } else { + int!(0) + } + }; + + // Allow if and only if sender's current power level is greater than + // or equal to the invite level + if *incoming_event.event_type() == TimelineEventType::RoomThirdPartyInvite { + let invite_level = match &power_levels_event { + | Some(power_levels) => + deserialize_power_levels_content_invite( + power_levels.content().get(), + room_version, + )? + .invite, + | None => int!(0), + }; + + if sender_power_level < invite_level { + warn!("sender's cannot send invites in this room"); + return Ok(false); + } + + debug!("m.room.third_party_invite event was allowed"); + return Ok(true); + } + + // If the event type's required power level is greater than the sender's power + // level, reject If the event has a state_key that starts with an @ and does + // not match the sender, reject. + if !can_send_event(incoming_event, power_levels_event.as_ref(), sender_power_level) { + warn!("user cannot send event"); + return Ok(false); + } + + // If type is m.room.power_levels + if *incoming_event.event_type() == TimelineEventType::RoomPowerLevels { + debug!("starting m.room.power_levels check"); + + if let Some(required_pwr_lvl) = check_power_levels( + room_version, + incoming_event, + power_levels_event.as_ref(), + sender_power_level, + ) { + if !required_pwr_lvl { + warn!("m.room.power_levels was not allowed"); + return Ok(false); + } + } else { + warn!("m.room.power_levels was not allowed"); + return Ok(false); + } + debug!("m.room.power_levels event allowed"); + } + + // Room version 3: Redaction events are always accepted (provided the event is + // allowed by `events` and `events_default` in the power levels). However, + // servers should not apply or send redaction's to clients until both the + // redaction event and original event have been seen, and are valid. Servers + // should only apply redaction's to events where the sender's domains match, or + // the sender of the redaction has the appropriate permissions per the + // power levels. 
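For orientation, a minimal call-site sketch for `auth_check` above (not part of this patch): the `check_one` helper and the `snapshot` map are hypothetical stand-ins for whatever state lookup a caller already has, written in the same closure style the tests further down use; only `auth_check`, `RoomVersion`, `PduEvent` and the conduwuit `Result` conversion come from the patch.

    use std::collections::HashMap;

    use conduwuit::{state_res::auth_check, PduEvent, Result, RoomVersion};
    use ruma::events::StateEventType;

    /// Hypothetical helper: authenticate one PDU against an in-memory snapshot
    /// keyed by (event type, state key).
    async fn check_one(
        room_version: &RoomVersion,
        incoming: &PduEvent,
        snapshot: &HashMap<(StateEventType, String), PduEvent>,
    ) -> Result<bool> {
        let fetch_state = |ty: &'static StateEventType, key: &str| {
            // Resolve the lookup before the async block so the future owns its data.
            let found = snapshot.get(&(ty.clone(), key.to_owned())).cloned();
            async move { found }
        };

        // No third-party invite is considered in this sketch.
        Ok(auth_check(room_version, incoming, None, fetch_state).await?)
    }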
+ + if room_version.extra_redaction_checks + && *incoming_event.event_type() == TimelineEventType::RoomRedaction + { + let redact_level = match power_levels_event { + | Some(pl) => + deserialize_power_levels_content_redact(pl.content().get(), room_version)?.redact, + | None => int!(50), + }; + + if !check_redaction(room_version, incoming_event, sender_power_level, redact_level)? { + return Ok(false); + } + } + + debug!("allowing event passed all checks"); + Ok(true) +} + +// TODO deserializing the member, power, join_rules event contents is done in +// conduit just before this is called. Could they be passed in? +/// Does the user who sent this member event have required power levels to do +/// so. +/// +/// * `user` - Information about the membership event and user making the +/// request. +/// * `auth_events` - The set of auth events that relate to a membership event. +/// +/// This is generated by calling `auth_types_for_event` with the membership +/// event and the current State. +#[allow(clippy::too_many_arguments)] +fn valid_membership_change( + room_version: &RoomVersion, + target_user: &UserId, + target_user_membership_event: Option, + sender: &UserId, + sender_membership_event: Option, + current_event: impl Event, + current_third_party_invite: Option, + power_levels_event: Option, + join_rules_event: Option, + user_for_join_auth: Option<&UserId>, + user_for_join_auth_membership: &MembershipState, + create_room: impl Event, +) -> Result { + #[derive(Deserialize)] + struct GetThirdPartyInvite { + third_party_invite: Option>, + } + let content = current_event.content(); + + let target_membership = from_json_str::(content.get())?.membership; + let third_party_invite = + from_json_str::(content.get())?.third_party_invite; + + let sender_membership = match &sender_membership_event { + | Some(pdu) => from_json_str::(pdu.content().get())?.membership, + | None => MembershipState::Leave, + }; + let sender_is_joined = sender_membership == MembershipState::Join; + + let target_user_current_membership = match &target_user_membership_event { + | Some(pdu) => from_json_str::(pdu.content().get())?.membership, + | None => MembershipState::Leave, + }; + + let power_levels: RoomPowerLevelsEventContent = match &power_levels_event { + | Some(ev) => from_json_str(ev.content().get())?, + | None => RoomPowerLevelsEventContent::default(), + }; + + let sender_power = power_levels + .users + .get(sender) + .or_else(|| sender_is_joined.then_some(&power_levels.users_default)); + + let target_power = power_levels.users.get(target_user).or_else(|| { + (target_membership == MembershipState::Join).then_some(&power_levels.users_default) + }); + + let mut join_rules = JoinRule::Invite; + if let Some(jr) = &join_rules_event { + join_rules = from_json_str::(jr.content().get())?.join_rule; + } + + let power_levels_event_id = power_levels_event.as_ref().map(Event::event_id); + let sender_membership_event_id = sender_membership_event.as_ref().map(Event::event_id); + let target_user_membership_event_id = + target_user_membership_event.as_ref().map(Event::event_id); + + let user_for_join_auth_is_valid = if let Some(user_for_join_auth) = user_for_join_auth { + // Is the authorised user allowed to invite users into this room + let (auth_user_pl, invite_level) = if let Some(pl) = &power_levels_event { + // TODO Refactor all powerlevel parsing + let invite = + deserialize_power_levels_content_invite(pl.content().get(), room_version)?.invite; + + let content = + deserialize_power_levels_content_fields(pl.content().get(), 
room_version)?; + let user_pl = if let Some(level) = content.get_user_power(user_for_join_auth) { + *level + } else { + content.users_default + }; + + (user_pl, invite) + } else { + (int!(0), int!(0)) + }; + (user_for_join_auth_membership == &MembershipState::Join) + && (auth_user_pl >= invite_level) + } else { + // No auth user was given + false + }; + + Ok(match target_membership { + | MembershipState::Join => { + // 1. If the only previous event is an m.room.create and the state_key is the + // creator, + // allow + let mut prev_events = current_event.prev_events(); + + let prev_event_is_create_event = prev_events + .next() + .is_some_and(|event_id| event_id.borrow() == create_room.event_id().borrow()); + let no_more_prev_events = prev_events.next().is_none(); + + if prev_event_is_create_event && no_more_prev_events { + let is_creator = if room_version.use_room_create_sender { + let creator = create_room.sender(); + + creator == sender && creator == target_user + } else { + #[allow(deprecated)] + let creator = from_json_str::(create_room.content().get())? + .creator + .ok_or_else(|| serde_json::Error::missing_field("creator"))?; + + creator == sender && creator == target_user + }; + + if is_creator { + return Ok(true); + } + } + + if sender != target_user { + // If the sender does not match state_key, reject. + warn!("Can't make other user join"); + false + } else if target_user_current_membership == MembershipState::Ban { + // If the sender is banned, reject. + warn!(?target_user_membership_event_id, "Banned user can't join"); + false + } else if (join_rules == JoinRule::Invite + || room_version.allow_knocking && join_rules == JoinRule::Knock) + // If the join_rule is invite then allow if membership state is invite or join + && (target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Invite) + { + true + } else if room_version.restricted_join_rules + && matches!(join_rules, JoinRule::Restricted(_)) + || room_version.knock_restricted_join_rule + && matches!(join_rules, JoinRule::KnockRestricted(_)) + { + // If the join_rule is restricted or knock_restricted + if matches!( + target_user_current_membership, + MembershipState::Invite | MembershipState::Join + ) { + // If membership state is join or invite, allow. + true + } else { + // If the join_authorised_via_users_server key in content is not a user with + // sufficient permission to invite other users, reject. + // Otherwise, allow. + user_for_join_auth_is_valid + } + } else { + // If the join_rule is public, allow. + // Otherwise, reject. 
+ join_rules == JoinRule::Public + } + }, + | MembershipState::Invite => { + // If content has third_party_invite key + if let Some(tp_id) = third_party_invite.and_then(|i| i.deserialize().ok()) { + if target_user_current_membership == MembershipState::Ban { + warn!(?target_user_membership_event_id, "Can't invite banned user"); + false + } else { + let allow = verify_third_party_invite( + Some(target_user), + sender, + &tp_id, + current_third_party_invite, + ); + if !allow { + warn!("Third party invite invalid"); + } + allow + } + } else if !sender_is_joined + || target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Ban + { + warn!( + ?target_user_membership_event_id, + ?sender_membership_event_id, + "Can't invite user if sender not joined or the user is currently joined or \ + banned", + ); + false + } else { + let allow = sender_power + .filter(|&p| p >= &power_levels.invite) + .is_some(); + if !allow { + warn!( + ?target_user_membership_event_id, + ?power_levels_event_id, + "User does not have enough power to invite", + ); + } + allow + } + }, + | MembershipState::Leave => + if sender == target_user { + let allow = target_user_current_membership == MembershipState::Join + || target_user_current_membership == MembershipState::Invite + || target_user_current_membership == MembershipState::Knock; + if !allow { + warn!( + ?target_user_membership_event_id, + ?target_user_current_membership, + "Can't leave if sender is not already invited, knocked, or joined" + ); + } + allow + } else if !sender_is_joined + || target_user_current_membership == MembershipState::Ban + && sender_power.filter(|&p| p < &power_levels.ban).is_some() + { + warn!( + ?target_user_membership_event_id, + ?sender_membership_event_id, + "Can't kick if sender not joined or user is already banned", + ); + false + } else { + let allow = sender_power.filter(|&p| p >= &power_levels.kick).is_some() + && target_power < sender_power; + if !allow { + warn!( + ?target_user_membership_event_id, + ?power_levels_event_id, + "User does not have enough power to kick", + ); + } + allow + }, + | MembershipState::Ban => + if !sender_is_joined { + warn!(?sender_membership_event_id, "Can't ban user if sender is not joined"); + false + } else { + let allow = sender_power.filter(|&p| p >= &power_levels.ban).is_some() + && target_power < sender_power; + if !allow { + warn!( + ?target_user_membership_event_id, + ?power_levels_event_id, + "User does not have enough power to ban", + ); + } + allow + }, + | MembershipState::Knock if room_version.allow_knocking => { + // 1. If the `join_rule` is anything other than `knock` or `knock_restricted`, + // reject. + if !matches!(join_rules, JoinRule::KnockRestricted(_) | JoinRule::Knock) { + warn!( + "Join rule is not set to knock or knock_restricted, knocking is not allowed" + ); + false + } else if matches!(join_rules, JoinRule::KnockRestricted(_)) + && !room_version.knock_restricted_join_rule + { + // 2. If the `join_rule` is `knock_restricted`, but the room does not support + // `knock_restricted`, reject. + warn!( + "Join rule is set to knock_restricted but room version does not support \ + knock_restricted, knocking is not allowed" + ); + false + } else if sender != target_user { + // 3. If `sender` does not match `state_key`, reject. 
+ warn!( + ?sender, + ?target_user, + "Can't make another user knock, sender did not match target" + ); + false + } else if matches!( + sender_membership, + MembershipState::Ban | MembershipState::Invite | MembershipState::Join + ) { + // 4. If the `sender`'s current membership is not `ban`, `invite`, or `join`, + // allow. + // 5. Otherwise, reject. + warn!( + ?target_user_membership_event_id, + "Knocking with a membership state of ban, invite or join is invalid", + ); + false + } else { + true + } + }, + | _ => { + warn!("Unknown membership transition"); + false + }, + }) +} + +/// Is the user allowed to send a specific event based on the rooms power +/// levels. +/// +/// Does the event have the correct userId as its state_key if it's not the "" +/// state_key. +fn can_send_event(event: impl Event, ple: Option, user_level: Int) -> bool { + let event_type_power_level = get_send_level(event.event_type(), event.state_key(), ple); + + debug!( + required_level = i64::from(event_type_power_level), + user_level = i64::from(user_level), + state_key = ?event.state_key(), + "permissions factors", + ); + + if user_level < event_type_power_level { + return false; + } + + if event.state_key().is_some_and(|k| k.starts_with('@')) + && event.state_key() != Some(event.sender().as_str()) + { + return false; // permission required to post in this room + } + + true +} + +/// Confirm that the event sender has the required power levels. +fn check_power_levels( + room_version: &RoomVersion, + power_event: impl Event, + previous_power_event: Option, + user_level: Int, +) -> Option { + match power_event.state_key() { + | Some("") => {}, + | Some(key) => { + error!(state_key = key, "m.room.power_levels event has non-empty state key"); + return None; + }, + | None => { + error!("check_power_levels requires an m.room.power_levels *state* event argument"); + return None; + }, + } + + // - If any of the keys users_default, events_default, state_default, ban, + // redact, kick, or invite in content are present and not an integer, reject. + // - If either of the keys events or notifications in content are present and + // not a dictionary with values that are integers, reject. + // - If users key in content is not a dictionary with keys that are valid user + // IDs with values that are integers, reject. 
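To make the comparison used further down concrete — `old_level > Some(&user_level) || new_level > Some(&user_level)` — here is a hypothetical helper and test (not code from this patch); it models only the "greater than own level" rejection and leaves out the separate rule that a level equal to the sender's own, belonging to another user, may not be changed.

    use ruma::{int, Int};

    /// Hypothetical mirror of the rejection test below. A missing entry (`None`)
    /// compares less than any present level, so it can never be "too big".
    fn level_change_too_big(sender_level: Int, old: Option<Int>, new: Option<Int>) -> bool {
        old > Some(sender_level) || new > Some(sender_level)
    }

    #[test]
    fn moderator_cannot_exceed_own_level() {
        // A sender at 50 may assign levels up to 50, including brand-new entries...
        assert!(!level_change_too_big(int!(50), None, Some(int!(50))));
        // ...but may not touch an entry that already sits above their own level.
        assert!(level_change_too_big(int!(50), Some(int!(100)), Some(int!(0))));
    }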
+ let user_content: RoomPowerLevelsEventContent = + deserialize_power_levels(power_event.content().get(), room_version)?; + + // Validation of users is done in Ruma, synapse for loops validating user_ids + // and integers here + debug!("validation of power event finished"); + + let current_state = match previous_power_event { + | Some(current_state) => current_state, + // If there is no previous m.room.power_levels event in the room, allow + | None => return Some(true), + }; + + let current_content: RoomPowerLevelsEventContent = + deserialize_power_levels(current_state.content().get(), room_version)?; + + let mut user_levels_to_check = BTreeSet::new(); + let old_list = ¤t_content.users; + let user_list = &user_content.users; + for user in old_list.keys().chain(user_list.keys()) { + let user: &UserId = user; + user_levels_to_check.insert(user); + } + + trace!(set = ?user_levels_to_check, "user levels to check"); + + let mut event_levels_to_check = BTreeSet::new(); + let old_list = ¤t_content.events; + let new_list = &user_content.events; + for ev_id in old_list.keys().chain(new_list.keys()) { + event_levels_to_check.insert(ev_id); + } + + trace!(set = ?event_levels_to_check, "event levels to check"); + + let old_state = ¤t_content; + let new_state = &user_content; + + // synapse does not have to split up these checks since we can't combine UserIds + // and EventTypes we do 2 loops + + // UserId loop + for user in user_levels_to_check { + let old_level = old_state.users.get(user); + let new_level = new_state.users.get(user); + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; + } + + // If the current value is equal to the sender's current power level, reject + if user != power_event.sender() && old_level == Some(&user_level) { + warn!("m.room.power_level cannot remove ops == to own"); + return Some(false); // cannot remove ops level == to own + } + + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = old_level > Some(&user_level); + let new_level_too_big = new_level > Some(&user_level); + if old_level_too_big || new_level_too_big { + warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + + // EventType loop + for ev_type in event_levels_to_check { + let old_level = old_state.events.get(ev_type); + let new_level = new_state.events.get(ev_type); + if old_level.is_some() && new_level.is_some() && old_level == new_level { + continue; + } + + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = old_level > Some(&user_level); + let new_level_too_big = new_level > Some(&user_level); + if old_level_too_big || new_level_too_big { + warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + + // Notifications, currently there is only @room + if room_version.limit_notifications_power_levels { + let old_level = old_state.notifications.room; + let new_level = new_state.notifications.room; + if old_level != new_level { + // If the current value is higher than the sender's current power level, reject + // If the new value is higher than the sender's current power level, reject + let old_level_too_big = old_level > user_level; + let new_level_too_big = new_level > user_level; + 
if old_level_too_big || new_level_too_big { + warn!("m.room.power_level failed to add ops > than own"); + return Some(false); // cannot add ops greater than own + } + } + } + + let levels = [ + "users_default", + "events_default", + "state_default", + "ban", + "redact", + "kick", + "invite", + ]; + let old_state = serde_json::to_value(old_state).unwrap(); + let new_state = serde_json::to_value(new_state).unwrap(); + for lvl_name in &levels { + if let Some((old_lvl, new_lvl)) = get_deserialize_levels(&old_state, &new_state, lvl_name) + { + let old_level_too_big = old_lvl > user_level; + let new_level_too_big = new_lvl > user_level; + + if old_level_too_big || new_level_too_big { + warn!("cannot add ops > than own"); + return Some(false); + } + } + } + + Some(true) +} + +fn get_deserialize_levels( + old: &serde_json::Value, + new: &serde_json::Value, + name: &str, +) -> Option<(Int, Int)> { + Some(( + serde_json::from_value(old.get(name)?.clone()).ok()?, + serde_json::from_value(new.get(name)?.clone()).ok()?, + )) +} + +/// Does the event redacting come from a user with enough power to redact the +/// given event. +fn check_redaction( + _room_version: &RoomVersion, + redaction_event: impl Event, + user_level: Int, + redact_level: Int, +) -> Result { + if user_level >= redact_level { + debug!("redaction allowed via power levels"); + return Ok(true); + } + + // If the domain of the event_id of the event being redacted is the same as the + // domain of the event_id of the m.room.redaction, allow + if redaction_event.event_id().borrow().server_name() + == redaction_event + .redacts() + .as_ref() + .and_then(|&id| id.borrow().server_name()) + { + debug!("redaction event allowed via room version 1 rules"); + return Ok(true); + } + + Ok(false) +} + +/// Helper function to fetch the power level needed to send an event of type +/// `e_type` based on the rooms "m.room.power_level" event. +fn get_send_level( + e_type: &TimelineEventType, + state_key: Option<&str>, + power_lvl: Option, +) -> Int { + power_lvl + .and_then(|ple| { + from_json_str::(ple.content().get()) + .map(|content| { + content.events.get(e_type).copied().unwrap_or_else(|| { + if state_key.is_some() { + content.state_default + } else { + content.events_default + } + }) + }) + .ok() + }) + .unwrap_or_else(|| if state_key.is_some() { int!(50) } else { int!(0) }) +} + +fn verify_third_party_invite( + target_user: Option<&UserId>, + sender: &UserId, + tp_id: &ThirdPartyInvite, + current_third_party_invite: Option, +) -> bool { + // 1. Check for user being banned happens before this is called + // checking for mxid and token keys is done by ruma when deserializing + + // The state key must match the invitee + if target_user != Some(&tp_id.signed.mxid) { + return false; + } + + // If there is no m.room.third_party_invite event in the current room state with + // state_key matching token, reject + let current_tpid = match current_third_party_invite { + | Some(id) => id, + | None => return false, + }; + + if current_tpid.state_key() != Some(&tp_id.signed.token) { + return false; + } + + if sender != current_tpid.sender() { + return false; + } + + // If any signature in signed matches any public key in the + // m.room.third_party_invite event, allow + let tpid_ev = + match from_json_str::(current_tpid.content().get()) { + | Ok(ev) => ev, + | Err(_) => return false, + }; + + let decoded_invite_token = match Base64::parse(&tp_id.signed.token) { + | Ok(tok) => tok, + // FIXME: Log a warning? 
+ | Err(_) => return false, + }; + + // A list of public keys in the public_keys field + for key in tpid_ev.public_keys.unwrap_or_default() { + if key.public_key == decoded_invite_token { + return true; + } + } + + // A single public key in the public_key field + tpid_ev.public_key == decoded_invite_token +} + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use ruma_events::{ + room::{ + join_rules::{ + AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership, + }, + member::{MembershipState, RoomMemberEventContent}, + }, + StateEventType, TimelineEventType, + }; + use serde_json::value::to_raw_value as to_raw_json_value; + + use crate::{ + event_auth::valid_membership_change, + test_utils::{ + alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id, + to_pdu_event, PduEvent, INITIAL_EVENTS, INITIAL_EVENTS_CREATE_ROOM, + }, + Event, EventTypeExt, RoomVersion, StateMap, + }; + + #[test] + fn test_ban_pass() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let events = INITIAL_EVENTS(); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + alice(), + TimelineEventType::RoomMember, + Some(charlie().as_str()), + member_content_ban(), + &[], + &["IMC"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = charlie(); + let sender = alice(); + + assert!(valid_membership_change( + &RoomVersion::V6, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None, + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } + + #[test] + fn test_join_non_creator() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let events = INITIAL_EVENTS_CREATE_ROOM(); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + charlie(), + TimelineEventType::RoomMember, + Some(charlie().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = charlie(); + let sender = charlie(); + + assert!(!valid_membership_change( + &RoomVersion::V6, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None, + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } + + #[test] + fn test_join_creator() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let events = INITIAL_EVENTS_CREATE_ROOM(); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + alice(), + 
TimelineEventType::RoomMember, + Some(alice().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = alice(); + let sender = alice(); + + assert!(valid_membership_change( + &RoomVersion::V6, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None, + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } + + #[test] + fn test_ban_fail() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let events = INITIAL_EVENTS(); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + charlie(), + TimelineEventType::RoomMember, + Some(alice().as_str()), + member_content_ban(), + &[], + &["IMC"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = alice(); + let sender = charlie(); + + assert!(!valid_membership_change( + &RoomVersion::V6, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None, + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } + + #[test] + fn test_restricted_join_rule() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let mut events = INITIAL_EVENTS(); + *events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Restricted( + Restricted::new(vec![AllowRule::RoomMembership(RoomMembership::new( + room_id().to_owned(), + ))]), + ))) + .unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ); + + let mut member = RoomMemberEventContent::new(MembershipState::Join); + member.join_authorized_via_users_server = Some(alice().to_owned()); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + ella(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap(), + &["CREATE", "IJR", "IPOWER", "new"], + &["new"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = ella(); + let sender = ella(); + + assert!(valid_membership_change( + &RoomVersion::V9, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + Some(alice()), + &MembershipState::Join, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + + 
assert!(!valid_membership_change( + &RoomVersion::V9, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + Some(ella()), + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } + + #[test] + fn test_knock() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let mut events = INITIAL_EVENTS(); + *events.get_mut(&event_id("IJR")).unwrap() = to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Knock)).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ); + + let auth_events = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), Arc::clone(ev))) + .collect::>(); + + let requester = to_pdu_event( + "HELLO", + ella(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Knock)).unwrap(), + &[], + &["IMC"], + ); + + let fetch_state = |ty, key| auth_events.get(&(ty, key)).cloned(); + let target_user = ella(); + let sender = ella(); + + assert!(valid_membership_change( + &RoomVersion::V7, + target_user, + fetch_state(StateEventType::RoomMember, target_user.to_string()), + sender, + fetch_state(StateEventType::RoomMember, sender.to_string()), + &requester, + None::, + fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), + fetch_state(StateEventType::RoomJoinRules, "".to_owned()), + None, + &MembershipState::Leave, + fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), + ) + .unwrap()); + } +} diff --git a/src/core/state_res/mod.rs b/src/core/state_res/mod.rs new file mode 100644 index 00000000..e4054377 --- /dev/null +++ b/src/core/state_res/mod.rs @@ -0,0 +1,1644 @@ +pub(crate) mod error; +pub mod event_auth; +mod power_levels; +mod room_version; +mod state_event; + +#[cfg(test)] +mod test_utils; + +use std::{ + borrow::Borrow, + cmp::{Ordering, Reverse}, + collections::{BinaryHeap, HashMap, HashSet}, + fmt::Debug, + hash::Hash, +}; + +use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; +use ruma::{ + events::{ + room::member::{MembershipState, RoomMemberEventContent}, + StateEventType, TimelineEventType, + }, + int, EventId, Int, MilliSecondsSinceUnixEpoch, RoomVersionId, +}; +use serde_json::from_str as from_json_str; + +pub(crate) use self::error::Error; +use self::power_levels::PowerLevelsContentFields; +pub use self::{ + event_auth::{auth_check, auth_types_for_event}, + room_version::RoomVersion, + state_event::Event, +}; +use crate::{debug, trace, warn}; + +/// A mapping of event type and state_key to some value `T`, usually an +/// `EventId`. +pub type StateMap = HashMap; +pub type StateMapItem = (TypeStateKey, T); +pub type TypeStateKey = (StateEventType, String); + +type Result = crate::Result; + +/// Resolve sets of state events as they come in. +/// +/// Internally `StateResolution` builds a graph and an auth chain to allow for +/// state conflict resolution. +/// +/// ## Arguments +/// +/// * `state_sets` - The incoming state to resolve. Each `StateMap` represents a +/// possible fork in the state of a room. 
+/// +/// * `auth_chain_sets` - The full recursive set of `auth_events` for each event +/// in the `state_sets`. +/// +/// * `event_fetch` - Any event not found in the `event_map` will defer to this +/// closure to find the event. +/// +/// * `parallel_fetches` - The number of asynchronous fetch requests in-flight +/// for any given operation. +/// +/// ## Invariants +/// +/// The caller of `resolve` must ensure that all the events are from the same +/// room. Although this function takes a `RoomId` it does not check that each +/// event is part of the same room. +//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets, +//#[tracing::instrument(level event_fetch))] +pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( + room_version: &RoomVersionId, + state_sets: impl IntoIterator + Send, + auth_chain_sets: &'a [HashSet], + event_fetch: &Fetch, + event_exists: &Exists, + parallel_fetches: usize, +) -> Result> +where + Fetch: Fn(E::Id) -> FetchFut + Sync, + FetchFut: Future> + Send, + Exists: Fn(E::Id) -> ExistsFut + Sync, + ExistsFut: Future + Send, + SetIter: Iterator> + Clone + Send, + E: Event + Clone + Send + Sync, + E::Id: Borrow + Send + Sync, + for<'b> &'b E: Send, +{ + debug!("State resolution starting"); + + // Split non-conflicting and conflicting state + let (clean, conflicting) = separate(state_sets.into_iter()); + + debug!(count = clean.len(), "non-conflicting events"); + trace!(map = ?clean, "non-conflicting events"); + + if conflicting.is_empty() { + debug!("no conflicting state found"); + return Ok(clean); + } + + debug!(count = conflicting.len(), "conflicting events"); + trace!(map = ?conflicting, "conflicting events"); + + let auth_chain_diff = + get_auth_chain_diff(auth_chain_sets).chain(conflicting.into_values().flatten()); + + // `all_conflicted` contains unique items + // synapse says `full_set = {eid for eid in full_conflicted_set if eid in + // event_map}` + let all_conflicted: HashSet<_> = stream::iter(auth_chain_diff) + // Don't honor events we cannot "verify" + .map(|id| event_exists(id.clone()).map(move |exists| (id, exists))) + .buffer_unordered(parallel_fetches) + .filter_map(|(id, exists)| future::ready(exists.then_some(id))) + .collect() + .boxed() + .await; + + debug!(count = all_conflicted.len(), "full conflicted set"); + trace!(set = ?all_conflicted, "full conflicted set"); + + // We used to check that all events are events from the correct room + // this is now a check the caller of `resolve` must make. + + // Get only the control events with a state_key: "" or ban/kick event (sender != + // state_key) + let control_events: Vec<_> = stream::iter(all_conflicted.iter()) + .map(|id| is_power_event_id(id, &event_fetch).map(move |is| (id, is))) + .buffer_unordered(parallel_fetches) + .filter_map(|(id, is)| future::ready(is.then_some(id.clone()))) + .collect() + .boxed() + .await; + + // Sort the control events based on power_level/clock/event_id and + // outgoing/incoming edges + let sorted_control_levels = reverse_topological_power_sort( + control_events, + &all_conflicted, + &event_fetch, + parallel_fetches, + ) + .boxed() + .await?; + + debug!(count = sorted_control_levels.len(), "power events"); + trace!(list = ?sorted_control_levels, "sorted power events"); + + let room_version = RoomVersion::new(room_version)?; + // Sequentially auth check each control event. 
+ let resolved_control = iterative_auth_check( + &room_version, + sorted_control_levels.iter(), + clean.clone(), + &event_fetch, + parallel_fetches, + ) + .boxed() + .await?; + + debug!(count = resolved_control.len(), "resolved power events"); + trace!(map = ?resolved_control, "resolved power events"); + + // At this point the control_events have been resolved we now have to + // sort the remaining events using the mainline of the resolved power level. + let deduped_power_ev = sorted_control_levels.into_iter().collect::>(); + + // This removes the control events that passed auth and more importantly those + // that failed auth + let events_to_resolve = all_conflicted + .iter() + .filter(|&id| !deduped_power_ev.contains(id.borrow())) + .cloned() + .collect::>(); + + debug!(count = events_to_resolve.len(), "events left to resolve"); + trace!(list = ?events_to_resolve, "events left to resolve"); + + // This "epochs" power level event + let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new())); + + debug!(event_id = ?power_event, "power event"); + + let sorted_left_events = + mainline_sort(&events_to_resolve, power_event.cloned(), &event_fetch, parallel_fetches) + .boxed() + .await?; + + trace!(list = ?sorted_left_events, "events left, sorted"); + + let mut resolved_state = iterative_auth_check( + &room_version, + sorted_left_events.iter(), + resolved_control, // The control events are added to the final resolved state + &event_fetch, + parallel_fetches, + ) + .boxed() + .await?; + + // Add unconflicted state to the resolved state + // We priorities the unconflicting state + resolved_state.extend(clean); + + debug!("state resolution finished"); + + Ok(resolved_state) +} + +/// Split the events that have no conflicts from those that are conflicting. +/// +/// The return tuple looks like `(unconflicted, conflicted)`. +/// +/// State is determined to be conflicting if for the given key (StateEventType, +/// StateKey) there is not exactly one event ID. This includes missing events, +/// if one state_set includes an event that none of the other have this is a +/// conflicting event. +fn separate<'a, Id>( + state_sets_iter: impl Iterator>, +) -> (StateMap, StateMap>) +where + Id: Clone + Eq + Hash + 'a, +{ + let mut state_set_count = 0_usize; + let mut occurrences = HashMap::<_, HashMap<_, _>>::new(); + + let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1); + for (k, v) in state_sets_iter.flatten() { + occurrences + .entry(k) + .or_default() + .entry(v) + .and_modify(|x| *x += 1) + .or_insert(1); + } + + let mut unconflicted_state = StateMap::new(); + let mut conflicted_state = StateMap::new(); + + for (k, v) in occurrences { + for (id, occurrence_count) in v { + if occurrence_count == state_set_count { + unconflicted_state.insert((k.0.clone(), k.1.clone()), id.clone()); + } else { + conflicted_state + .entry((k.0.clone(), k.1.clone())) + .and_modify(|x: &mut Vec<_>| x.push(id.clone())) + .or_insert(vec![id.clone()]); + } + } + } + + (unconflicted_state, conflicted_state) +} + +/// Returns a Vec of deduped EventIds that appear in some chains but not others. 
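Stepping back to the public entry point: a call-site sketch for `resolve` above, not part of the patch. `resolve_two_forks`, the `events` store and the fork/auth-chain arguments are hypothetical; `resolve`, `StateMap` and `PduEvent` are the pieces this patch actually provides.

    use std::collections::{HashMap, HashSet};

    use conduwuit::{state_res, PduEvent, Result, StateMap};
    use ruma::{OwnedEventId, RoomVersionId};

    /// Hypothetical driver: resolve two forks of a room's state against an
    /// in-memory store of already-fetched PDUs.
    async fn resolve_two_forks(
        room_version: &RoomVersionId,
        events: &HashMap<OwnedEventId, PduEvent>,
        fork_a: &StateMap<OwnedEventId>,
        fork_b: &StateMap<OwnedEventId>,
        auth_chain_sets: &[HashSet<OwnedEventId>],
    ) -> Result<StateMap<OwnedEventId>> {
        let fetch = |id: OwnedEventId| {
            let pdu = events.get(&id).cloned();
            async move { pdu }
        };
        let exists = |id: OwnedEventId| {
            let found = events.contains_key(&id);
            async move { found }
        };

        let resolved = state_res::resolve(
            room_version,
            [fork_a, fork_b],
            auth_chain_sets,
            &fetch,
            &exists,
            32, // parallel_fetches: in-flight event lookups
        )
        .await?;

        Ok(resolved)
    }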
+fn get_auth_chain_diff(auth_chain_sets: &[HashSet]) -> impl Iterator + Send +where + Id: Clone + Eq + Hash + Send, +{ + let num_sets = auth_chain_sets.len(); + let mut id_counts: HashMap = HashMap::new(); + for id in auth_chain_sets.iter().flatten() { + *id_counts.entry(id.clone()).or_default() += 1; + } + + id_counts + .into_iter() + .filter_map(move |(id, count)| (count < num_sets).then_some(id)) +} + +/// Events are sorted from "earliest" to "latest". +/// +/// They are compared using the negative power level (reverse topological +/// ordering), the origin server timestamp and in case of a tie the `EventId`s +/// are compared lexicographically. +/// +/// The power level is negative because a higher power level is equated to an +/// earlier (further back in time) origin server timestamp. +#[tracing::instrument(level = "debug", skip_all)] +async fn reverse_topological_power_sort( + events_to_sort: Vec, + auth_diff: &HashSet, + fetch_event: &F, + parallel_fetches: usize, +) -> Result> +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send + Sync, +{ + debug!("reverse topological sort of power events"); + + let mut graph = HashMap::new(); + for event_id in events_to_sort { + add_event_and_auth_chain_to_graph(&mut graph, event_id, auth_diff, fetch_event).await; + } + + // This is used in the `key_fn` passed to the lexico_topo_sort fn + let event_to_pl = stream::iter(graph.keys()) + .map(|event_id| { + get_power_level_for_sender(event_id.clone(), fetch_event, parallel_fetches) + .map(move |res| res.map(|pl| (event_id, pl))) + }) + .buffer_unordered(parallel_fetches) + .try_fold(HashMap::new(), |mut event_to_pl, (event_id, pl)| { + debug!( + event_id = event_id.borrow().as_str(), + power_level = i64::from(pl), + "found the power level of an event's sender", + ); + + event_to_pl.insert(event_id.clone(), pl); + future::ok(event_to_pl) + }) + .boxed() + .await?; + + let event_to_pl = &event_to_pl; + let fetcher = |event_id: E::Id| async move { + let pl = *event_to_pl + .get(event_id.borrow()) + .ok_or_else(|| Error::NotFound(String::new()))?; + let ev = fetch_event(event_id) + .await + .ok_or_else(|| Error::NotFound(String::new()))?; + Ok((pl, ev.origin_server_ts())) + }; + + lexicographical_topological_sort(&graph, &fetcher).await +} + +/// Sorts the event graph based on number of outgoing/incoming edges. +/// +/// `key_fn` is used as to obtain the power level and age of an event for +/// breaking ties (together with the event ID). +#[tracing::instrument(level = "debug", skip_all)] +pub async fn lexicographical_topological_sort( + graph: &HashMap>, + key_fn: &F, +) -> Result> +where + F: Fn(Id) -> Fut + Sync, + Fut: Future> + Send, + Id: Borrow + Clone + Eq + Hash + Ord + Send, +{ + #[derive(PartialEq, Eq)] + struct TieBreaker<'a, Id> { + power_level: Int, + origin_server_ts: MilliSecondsSinceUnixEpoch, + event_id: &'a Id, + } + + impl Ord for TieBreaker<'_, Id> + where + Id: Ord, + { + fn cmp(&self, other: &Self) -> Ordering { + // NOTE: the power level comparison is "backwards" intentionally. + // See the "Mainline ordering" section of the Matrix specification + // around where it says the following: + // + // > for events `x` and `y`, `x < y` if [...] 
+ // + // + other + .power_level + .cmp(&self.power_level) + .then(self.origin_server_ts.cmp(&other.origin_server_ts)) + .then(self.event_id.cmp(other.event_id)) + } + } + + impl PartialOrd for TieBreaker<'_, Id> + where + Id: Ord, + { + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } + } + + debug!("starting lexicographical topological sort"); + + // NOTE: an event that has no incoming edges happened most recently, + // and an event that has no outgoing edges happened least recently. + + // NOTE: this is basically Kahn's algorithm except we look at nodes with no + // outgoing edges, c.f. + // https://en.wikipedia.org/wiki/Topological_sorting#Kahn's_algorithm + + // outdegree_map is an event referring to the events before it, the + // more outdegree's the more recent the event. + let mut outdegree_map = graph.clone(); + + // The number of events that depend on the given event (the EventId key) + // How many events reference this event in the DAG as a parent + let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new(); + + // Vec of nodes that have zero out degree, least recent events. + let mut zero_outdegree = Vec::new(); + + for (node, edges) in graph { + if edges.is_empty() { + let (power_level, origin_server_ts) = key_fn(node.clone()).await?; + // The `Reverse` is because rusts `BinaryHeap` sorts largest -> smallest we need + // smallest -> largest + zero_outdegree.push(Reverse(TieBreaker { + power_level, + origin_server_ts, + event_id: node, + })); + } + + reverse_graph.entry(node).or_default(); + for edge in edges { + reverse_graph.entry(edge).or_default().insert(node); + } + } + + let mut heap = BinaryHeap::from(zero_outdegree); + + // We remove the oldest node (most incoming edges) and check against all other + let mut sorted = vec![]; + // Destructure the `Reverse` and take the smallest `node` each time + while let Some(Reverse(item)) = heap.pop() { + let node = item.event_id; + + for &parent in reverse_graph + .get(node) + .expect("EventId in heap is also in reverse_graph") + { + // The number of outgoing edges this node has + let out = outdegree_map + .get_mut(parent.borrow()) + .expect("outdegree_map knows of all referenced EventIds"); + + // Only push on the heap once older events have been cleared + out.remove(node.borrow()); + if out.is_empty() { + let (power_level, origin_server_ts) = key_fn(parent.clone()).await?; + heap.push(Reverse(TieBreaker { + power_level, + origin_server_ts, + event_id: parent, + })); + } + } + + // synapse yields we push then return the vec + sorted.push(node.clone()); + } + + Ok(sorted) +} + +/// Find the power level for the sender of `event_id` or return a default value +/// of zero. +/// +/// Do NOT use this any where but topological sort, we find the power level for +/// the eventId at the eventId's generation (we walk backwards to `EventId`s +/// most recent previous power level event). 
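+///
+/// A rough sketch of the lookup with a hypothetical power-levels content,
+/// using the `alice()` test helper defined below (not a compiled doctest):
+///
+/// ```ignore
+/// // content of the closest m.room.power_levels event in the auth chain
+/// let content: PowerLevelsContentFields =
+/// 	from_json_str(r#"{ "users": { "@alice:foo": 100 }, "users_default": 0 }"#)?;
+/// // sender listed in `users`    -> that value
+/// assert_eq!(content.get_user_power(alice()), Some(&int!(100)));
+/// // sender not listed           -> `users_default`
+/// assert_eq!(content.users_default, int!(0));
+/// // no power-levels auth event at all -> the function returns int!(0)
+/// ```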
+async fn get_power_level_for_sender( + event_id: E::Id, + fetch_event: &F, + parallel_fetches: usize, +) -> serde_json::Result +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send, +{ + debug!("fetch event ({event_id}) senders power level"); + + let event = fetch_event(event_id.clone()).await; + + let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten(); + + let pl = stream::iter(auth_events) + .map(|aid| fetch_event(aid.clone())) + .buffer_unordered(parallel_fetches.min(5)) + .filter_map(future::ready) + .collect::>() + .boxed() + .await + .into_iter() + .find(|aev| is_type_and_key(aev, &TimelineEventType::RoomPowerLevels, "")); + + let content: PowerLevelsContentFields = match pl { + | None => return Ok(int!(0)), + | Some(ev) => from_json_str(ev.content().get())?, + }; + + if let Some(ev) = event { + if let Some(&user_level) = content.get_user_power(ev.sender()) { + debug!("found {} at power_level {user_level}", ev.sender()); + return Ok(user_level); + } + } + + Ok(content.users_default) +} + +/// Check the that each event is authenticated based on the events before it. +/// +/// ## Returns +/// +/// The `unconflicted_state` combined with the newly auth'ed events. So any +/// event that fails the `event_auth::auth_check` will be excluded from the +/// returned state map. +/// +/// For each `events_to_check` event we gather the events needed to auth it from +/// the the `fetch_event` closure and verify each event using the +/// `event_auth::auth_check` function. +async fn iterative_auth_check<'a, E, F, Fut, I>( + room_version: &RoomVersion, + events_to_check: I, + unconflicted_state: StateMap, + fetch_event: &F, + parallel_fetches: usize, +) -> Result> +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E::Id: Borrow + Clone + Eq + Ord + Send + Sync + 'a, + I: Iterator + Debug + Send + 'a, + E: Event + Clone + Send + Sync, +{ + debug!("starting iterative auth check"); + trace!( + list = ?events_to_check, + "events to check" + ); + + let events_to_check: Vec<_> = stream::iter(events_to_check) + .map(Result::Ok) + .map_ok(|event_id| { + fetch_event(event_id.clone()).map(move |result| { + result.ok_or_else(|| Error::NotFound(format!("Failed to find {event_id}"))) + }) + }) + .try_buffer_unordered(parallel_fetches) + .try_collect() + .boxed() + .await?; + + let auth_event_ids: HashSet = events_to_check + .iter() + .flat_map(|event: &E| event.auth_events().map(Clone::clone)) + .collect(); + + let auth_events: HashMap = stream::iter(auth_event_ids.into_iter()) + .map(fetch_event) + .buffer_unordered(parallel_fetches) + .filter_map(future::ready) + .map(|auth_event| (auth_event.event_id().clone(), auth_event)) + .collect() + .boxed() + .await; + + let auth_events = &auth_events; + let mut resolved_state = unconflicted_state; + for event in &events_to_check { + let event_id = event.event_id(); + let state_key = event + .state_key() + .ok_or_else(|| Error::InvalidPdu("State event had no state key".to_owned()))?; + + let auth_types = auth_types_for_event( + event.event_type(), + event.sender(), + Some(state_key), + event.content(), + )?; + + let mut auth_state = StateMap::new(); + for aid in event.auth_events() { + if let Some(ev) = auth_events.get(aid.borrow()) { + //TODO: synapse checks "rejected_reason" which is most likely related to + // soft-failing + auth_state.insert( + ev.event_type() + .with_state_key(ev.state_key().ok_or_else(|| { + Error::InvalidPdu("State event had no state key".to_owned()) + })?), + 
+					ev.clone(),
+				);
+			} else {
+				warn!(event_id = aid.borrow().as_str(), "missing auth event");
+			}
+		}
+
+		stream::iter(
+			auth_types
+				.iter()
+				.filter_map(|key| Some((key, resolved_state.get(key)?))),
+		)
+		.filter_map(|(key, ev_id)| async move {
+			if let Some(event) = auth_events.get(ev_id.borrow()) {
+				Some((key, event.clone()))
+			} else {
+				Some((key, fetch_event(ev_id.clone()).await?))
+			}
+		})
+		.for_each(|(key, event)| {
+			//TODO: synapse checks "rejected_reason" is None here
+			auth_state.insert(key.to_owned(), event);
+			future::ready(())
+		})
+		.await;
+
+		debug!("event to check {:?}", event.event_id());
+
+		// The key for this is (event type + the state_key of the signed token, not
+		// the sender), so search for it
+		let current_third_party = auth_state.iter().find_map(|(_, pdu)| {
+			(*pdu.event_type() == TimelineEventType::RoomThirdPartyInvite).then_some(pdu)
+		});
+
+		let fetch_state = |ty: &StateEventType, key: &str| {
+			future::ready(auth_state.get(&ty.with_state_key(key)))
+		};
+
+		if auth_check(room_version, &event, current_third_party.as_ref(), fetch_state).await? {
+			// add event to resolved state map
+			resolved_state.insert(event.event_type().with_state_key(state_key), event_id.clone());
+		} else {
+			// synapse passes here on AuthError. We do not add this event to resolved_state.
+			warn!("event {event_id} failed the authentication check");
+		}
+	}
+
+	Ok(resolved_state)
+}
+
+/// Returns the sorted `to_sort` list of `EventId`s based on a mainline sort
+/// using the depth of `resolved_power_level`, the server timestamp, and the
+/// event ID.
+///
+/// The depth of the given event is calculated based on the depth of its
+/// closest "parent" power_level event. If there have been two power events,
+/// the events after the most recent are depth 0 and the events before (with
+/// the first power level as a parent) are marked as depth 1; depth 1 is
+/// "older" than depth 0.
+async fn mainline_sort<E, F, Fut>(
+	to_sort: &[E::Id],
+	resolved_power_level: Option<E::Id>,
+	fetch_event: &F,
+	parallel_fetches: usize,
+) -> Result<Vec<E::Id>>
+where
+	F: Fn(E::Id) -> Fut + Sync,
+	Fut: Future<Output = Option<E>> + Send,
+	E: Event + Clone + Send + Sync,
+	E::Id: Borrow<EventId> + Clone + Send + Sync,
+{
+	debug!("mainline sort of events");
+
+	// There are no EventIds to sort, bail.
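+	// For illustration only (hypothetical IDs in the style of the tests below):
+	// if the resolved power event is $PA2:foo and its power-levels ancestry walks
+	// back to $PA1:foo, the mainline becomes [$PA2, $PA1] and, after reversal,
+	// the mainline map is {$PA1: 0, $PA2: 1}. Every event in `to_sort` is then
+	// keyed by the mainline position of its closest power-levels ancestor, its
+	// origin_server_ts and its event ID before the final sort.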
+ if to_sort.is_empty() { + return Ok(vec![]); + } + + let mut mainline = vec![]; + let mut pl = resolved_power_level; + while let Some(p) = pl { + mainline.push(p.clone()); + + let event = fetch_event(p.clone()) + .await + .ok_or_else(|| Error::NotFound(format!("Failed to find {p}")))?; + pl = None; + for aid in event.auth_events() { + let ev = fetch_event(aid.clone()) + .await + .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; + if is_type_and_key(&ev, &TimelineEventType::RoomPowerLevels, "") { + pl = Some(aid.to_owned()); + break; + } + } + } + + let mainline_map = mainline + .iter() + .rev() + .enumerate() + .map(|(idx, eid)| ((*eid).clone(), idx)) + .collect::>(); + + let order_map = stream::iter(to_sort.iter()) + .map(|ev_id| { + fetch_event(ev_id.clone()).map(move |event| event.map(|event| (event, ev_id))) + }) + .buffer_unordered(parallel_fetches) + .filter_map(future::ready) + .map(|(event, ev_id)| { + get_mainline_depth(Some(event.clone()), &mainline_map, fetch_event) + .map_ok(move |depth| (depth, event, ev_id)) + .map(Result::ok) + }) + .buffer_unordered(parallel_fetches) + .filter_map(future::ready) + .fold(HashMap::new(), |mut order_map, (depth, event, ev_id)| { + order_map.insert(ev_id, (depth, event.origin_server_ts(), ev_id)); + future::ready(order_map) + }) + .boxed() + .await; + + // Sort the event_ids by their depth, timestamp and EventId + // unwrap is OK order map and sort_event_ids are from to_sort (the same Vec) + let mut sort_event_ids = order_map.keys().map(|&k| k.clone()).collect::>(); + sort_event_ids.sort_by_key(|sort_id| &order_map[sort_id]); + + Ok(sort_event_ids) +} + +/// Get the mainline depth from the `mainline_map` or finds a power_level event +/// that has an associated mainline depth. +async fn get_mainline_depth( + mut event: Option, + mainline_map: &HashMap, + fetch_event: &F, +) -> Result +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Send, +{ + while let Some(sort_ev) = event { + debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline"); + let id = sort_ev.event_id(); + if let Some(depth) = mainline_map.get(id.borrow()) { + return Ok(*depth); + } + + event = None; + for aid in sort_ev.auth_events() { + let aev = fetch_event(aid.clone()) + .await + .ok_or_else(|| Error::NotFound(format!("Failed to find {aid}")))?; + if is_type_and_key(&aev, &TimelineEventType::RoomPowerLevels, "") { + event = Some(aev); + break; + } + } + } + // Did not find a power level event so we default to zero + Ok(0) +} + +async fn add_event_and_auth_chain_to_graph( + graph: &mut HashMap>, + event_id: E::Id, + auth_diff: &HashSet, + fetch_event: &F, +) where + F: Fn(E::Id) -> Fut, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + Clone + Send, +{ + let mut state = vec![event_id]; + while let Some(eid) = state.pop() { + graph.entry(eid.clone()).or_default(); + let event = fetch_event(eid.clone()).await; + let auth_events = event.as_ref().map(Event::auth_events).into_iter().flatten(); + + // Prefer the store to event as the store filters dedups the events + for aid in auth_events { + if auth_diff.contains(aid.borrow()) { + if !graph.contains_key(aid.borrow()) { + state.push(aid.to_owned()); + } + + // We just inserted this at the start of the while loop + graph.get_mut(eid.borrow()).unwrap().insert(aid.to_owned()); + } + } + } +} + +async fn is_power_event_id(event_id: &E::Id, fetch: &F) -> bool +where + F: Fn(E::Id) -> Fut + Sync, + Fut: Future> + Send, + E: Event + Send, + E::Id: Borrow + 
Send, +{ + match fetch(event_id.clone()).await.as_ref() { + | Some(state) => is_power_event(state), + | _ => false, + } +} + +fn is_type_and_key(ev: impl Event, ev_type: &TimelineEventType, state_key: &str) -> bool { + ev.event_type() == ev_type && ev.state_key() == Some(state_key) +} + +fn is_power_event(event: impl Event) -> bool { + match event.event_type() { + | TimelineEventType::RoomPowerLevels + | TimelineEventType::RoomJoinRules + | TimelineEventType::RoomCreate => event.state_key() == Some(""), + | TimelineEventType::RoomMember => { + if let Ok(content) = from_json_str::(event.content().get()) { + if [MembershipState::Leave, MembershipState::Ban].contains(&content.membership) { + return Some(event.sender().as_str()) != event.state_key(); + } + } + + false + }, + | _ => false, + } +} + +/// Convenience trait for adding event type plus state key to state maps. +pub trait EventTypeExt { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); +} + +impl EventTypeExt for StateEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self, state_key.into()) + } +} + +impl EventTypeExt for TimelineEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self.to_string().into(), state_key.into()) + } +} + +impl EventTypeExt for &T +where + T: EventTypeExt + Clone, +{ + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + self.to_owned().with_state_key(state_key) + } +} + +#[cfg(test)] +mod tests { + use std::{ + collections::{HashMap, HashSet}, + sync::Arc, + }; + + use maplit::{hashmap, hashset}; + use rand::seq::SliceRandom; + use ruma::{ + events::{ + room::join_rules::{JoinRule, RoomJoinRulesEventContent}, + StateEventType, TimelineEventType, + }, + int, uint, + }; + use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId}; + use serde_json::{json, value::to_raw_value as to_raw_json_value}; + use tracing::debug; + + use crate::{ + is_power_event, + room_version::RoomVersion, + test_utils::{ + alice, bob, charlie, do_check, ella, event_id, member_content_ban, + member_content_join, room_id, to_init_pdu_event, to_pdu_event, zara, PduEvent, + TestStore, INITIAL_EVENTS, + }, + Event, EventTypeExt, StateMap, + }; + + async fn test_event_sort() { + use futures::future::ready; + + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let events = INITIAL_EVENTS(); + + let event_map = events + .values() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.clone())) + .collect::>(); + + let auth_chain: HashSet = HashSet::new(); + + let power_events = event_map + .values() + .filter(|&pdu| is_power_event(&**pdu)) + .map(|pdu| pdu.event_id.clone()) + .collect::>(); + + let fetcher = |id| ready(events.get(&id).cloned()); + let sorted_power_events = + crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) + .await + .unwrap(); + + let resolved_power = crate::iterative_auth_check( + &RoomVersion::V6, + sorted_power_events.iter(), + HashMap::new(), // unconflicted events + &fetcher, + 1, + ) + .await + .expect("iterative auth check failed on resolved events"); + + // don't remove any events so we know it sorts them all correctly + let mut events_to_sort = events.keys().cloned().collect::>(); + + events_to_sort.shuffle(&mut rand::thread_rng()); + + let power_level = resolved_power + .get(&(StateEventType::RoomPowerLevels, "".to_owned())) + .cloned(); + + let sorted_event_ids 
= crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1) + .await + .unwrap(); + + assert_eq!( + vec![ + "$CREATE:foo", + "$IMA:foo", + "$IPOWER:foo", + "$IJR:foo", + "$IMB:foo", + "$IMC:foo", + "$START:foo", + "$END:foo" + ], + sorted_event_ids + .iter() + .map(|id| id.to_string()) + .collect::>() + ); + } + + #[tokio::test] + async fn test_sort() { + for _ in 0..20 { + // since we shuffle the eventIds before we sort them introducing randomness + // seems like we should test this a few times + test_event_sort().await; + } + } + + #[tokio::test] + async fn ban_vs_power_level() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "PA", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "MA", + alice(), + TimelineEventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + ), + to_init_pdu_event( + "MB", + alice(), + TimelineEventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_ban(), + ), + to_init_pdu_event( + "PB", + bob(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + ]; + + let edges = vec![vec!["END", "MB", "MA", "PA", "START"], vec!["END", "PA", "PB"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA", "MA", "MB"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn topic_basic() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "T1", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "PA1", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "T2", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "PA2", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(), + ), + to_init_pdu_event( + "PB", + bob(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "T3", + bob(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + ]; + + let edges = + vec![vec!["END", "PA2", "T2", "PA1", "T1", "START"], vec!["END", "T3", "PB", "PA1"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA2", "T2"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn topic_reset() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "T1", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "PA", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ 
"users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "T2", + bob(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "MB", + alice(), + TimelineEventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_ban(), + ), + ]; + + let edges = vec![vec!["END", "MB", "T2", "PA", "T1", "START"], vec!["END", "T1"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["T1", "MB", "PA"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn join_rule_evasion() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "JR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Private)).unwrap(), + ), + to_init_pdu_event( + "ME", + ella(), + TimelineEventType::RoomMember, + Some(ella().to_string().as_str()), + member_content_join(), + ), + ]; + + let edges = vec![vec!["END", "JR", "START"], vec!["END", "ME", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec![event_id("JR")]; + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn offtopic_power_level() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "PA", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "PB", + bob(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value( + &json!({ "users": { alice(): 100, bob(): 50, charlie(): 50 } }), + ) + .unwrap(), + ), + to_init_pdu_event( + "PC", + charlie(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50, charlie(): 0 } })) + .unwrap(), + ), + ]; + + let edges = vec![vec!["END", "PC", "PB", "PA", "START"], vec!["END", "PA"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PC"].into_iter().map(event_id).collect::>(); + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn topic_setting() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let events = &[ + to_init_pdu_event( + "T1", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "PA1", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "T2", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "PA2", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 0 } })).unwrap(), + ), + to_init_pdu_event( + "PB", + bob(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + ), + to_init_pdu_event( + "T3", + bob(), + TimelineEventType::RoomTopic, + Some(""), + 
to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "MZ1", + zara(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + to_init_pdu_event( + "T4", + alice(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + ), + ]; + + let edges = vec![vec!["END", "T4", "MZ1", "PA2", "T2", "PA1", "T1", "START"], vec![ + "END", "MZ1", "T3", "PB", "PA1", + ]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["T4", "PA2"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(events, edges, expected_state_ids).await; + } + + #[tokio::test] + async fn test_event_map_none() { + use futures::future::ready; + + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let mut store = TestStore::(hashmap! {}); + + // build up the DAG + let (state_at_bob, state_at_charlie, expected) = store.set_up(); + + let ev_map = store.0.clone(); + let fetcher = |id| ready(ev_map.get(&id).cloned()); + + let exists = |id: ::Id| ready(ev_map.get(&*id).is_some()); + + let state_sets = [state_at_bob, state_at_charlie]; + let auth_chain: Vec<_> = state_sets + .iter() + .map(|map| { + store + .auth_event_ids(room_id(), map.values().cloned().collect()) + .unwrap() + }) + .collect(); + + let resolved = match crate::resolve( + &RoomVersionId::V2, + &state_sets, + &auth_chain, + &fetcher, + &exists, + 1, + ) + .await + { + | Ok(state) => state, + | Err(e) => panic!("{e}"), + }; + + assert_eq!(expected, resolved); + } + + #[tokio::test] + async fn test_lexicographical_sort() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + + let graph = hashmap! 
{ + event_id("l") => hashset![event_id("o")], + event_id("m") => hashset![event_id("n"), event_id("o")], + event_id("n") => hashset![event_id("o")], + event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges + event_id("p") => hashset![event_id("o")], + }; + + let res = crate::lexicographical_topological_sort(&graph, &|_id| async { + Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) + }) + .await + .unwrap(); + + assert_eq!( + vec!["o", "l", "n", "m", "p"], + res.iter() + .map(ToString::to_string) + .map(|s| s.replace('$', "").replace(":foo", "")) + .collect::>() + ); + } + + #[tokio::test] + async fn ban_with_auth_chains() { + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let ban = BAN_STATE_SET(); + + let edges = vec![vec!["END", "MB", "PA", "START"], vec!["END", "IME", "MB"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["PA", "MB"] + .into_iter() + .map(event_id) + .collect::>(); + + do_check(&ban.values().cloned().collect::>(), edges, expected_state_ids).await; + } + + #[tokio::test] + async fn ban_with_auth_chains2() { + use futures::future::ready; + + let _ = tracing::subscriber::set_default( + tracing_subscriber::fmt().with_test_writer().finish(), + ); + let init = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + let mut inner = init.clone(); + inner.extend(ban); + let store = TestStore(inner.clone()); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id.clone())) + .collect::>(); + + let ev_map = &store.0; + let state_sets = [state_set_a, state_set_b]; + let auth_chain: Vec<_> = state_sets + .iter() + .map(|map| { + store + .auth_event_ids(room_id(), map.values().cloned().collect()) + .unwrap() + }) + .collect(); + + let fetcher = |id: ::Id| ready(ev_map.get(&id).cloned()); + let exists = |id: ::Id| ready(ev_map.get(&id).is_some()); + let resolved = match crate::resolve( + &RoomVersionId::V6, + &state_sets, + &auth_chain, + &fetcher, + &exists, + 1, + ) + .await + { + | Ok(state) => state, + | Err(e) => panic!("{e}"), + }; + + debug!( + resolved = ?resolved + .iter() + .map(|((ty, key), id)| format!("(({ty}{key:?}), {id})")) + .collect::>(), + "resolved state", + ); + + let expected = [ + "$CREATE:foo", + "$IJR:foo", + "$PA:foo", + "$IMA:foo", + "$IMB:foo", + "$IMC:foo", + "$MB:foo", + ]; + + for id in expected.iter().map(|i| event_id(i)) { + // make sure our resolved events are equal to the expected list + assert!(resolved.values().any(|eid| eid == &id) || init.contains_key(&id), "{id}"); + } + assert_eq!(expected.len(), resolved.len()); + } + + #[tokio::test] + async fn join_rule_with_auth_chain() { + let join_rule = JOIN_RULE(); + + let edges = 
vec![vec!["END", "JR", "START"], vec!["END", "IMZ", "START"]] + .into_iter() + .map(|list| list.into_iter().map(event_id).collect::>()) + .collect::>(); + + let expected_state_ids = vec!["JR"].into_iter().map(event_id).collect::>(); + + do_check(&join_rule.values().cloned().collect::>(), edges, expected_state_ids) + .await; + } + + #[allow(non_snake_case)] + fn BAN_STATE_SET() -> HashMap> { + vec![ + to_pdu_event( + "PA", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id.clone(), ev)) + .collect() + } + + #[allow(non_snake_case)] + fn JOIN_RULE() -> HashMap> { + vec![ + to_pdu_event( + "JR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&json!({ "join_rule": "invite" })).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["START"], + ), + to_pdu_event( + "IMZ", + zara(), + TimelineEventType::RoomPowerLevels, + Some(zara().as_str()), + member_content_join(), + &["CREATE", "JR", "IPOWER"], + &["START"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id.clone(), ev)) + .collect() + } + + macro_rules! state_set { + ($($kind:expr => $key:expr => $id:expr),* $(,)?) 
=> {{ + #[allow(unused_mut)] + let mut x = StateMap::new(); + $( + x.insert(($kind, $key.to_owned()), $id); + )* + x + }}; + } + + #[test] + fn separate_unique_conflicted() { + let (unconflicted, conflicted) = super::separate( + [ + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + state_set![StateEventType::RoomMember => "@b:hs1" => 1], + state_set![StateEventType::RoomMember => "@c:hs1" => 2], + ] + .iter(), + ); + + assert_eq!(unconflicted, StateMap::new()); + assert_eq!(conflicted, state_set![ + StateEventType::RoomMember => "@a:hs1" => vec![0], + StateEventType::RoomMember => "@b:hs1" => vec![1], + StateEventType::RoomMember => "@c:hs1" => vec![2], + ],); + } + + #[test] + fn separate_conflicted() { + let (unconflicted, mut conflicted) = super::separate( + [ + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + state_set![StateEventType::RoomMember => "@a:hs1" => 1], + state_set![StateEventType::RoomMember => "@a:hs1" => 2], + ] + .iter(), + ); + + // HashMap iteration order is random, so sort this before asserting on it + for v in conflicted.values_mut() { + v.sort_unstable(); + } + + assert_eq!(unconflicted, StateMap::new()); + assert_eq!(conflicted, state_set![ + StateEventType::RoomMember => "@a:hs1" => vec![0, 1, 2], + ],); + } + + #[test] + fn separate_unconflicted() { + let (unconflicted, conflicted) = super::separate( + [ + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + ] + .iter(), + ); + + assert_eq!(unconflicted, state_set![ + StateEventType::RoomMember => "@a:hs1" => 0, + ],); + assert_eq!(conflicted, StateMap::new()); + } + + #[test] + fn separate_mixed() { + let (unconflicted, conflicted) = super::separate( + [ + state_set![StateEventType::RoomMember => "@a:hs1" => 0], + state_set![ + StateEventType::RoomMember => "@a:hs1" => 0, + StateEventType::RoomMember => "@b:hs1" => 1, + ], + state_set![ + StateEventType::RoomMember => "@a:hs1" => 0, + StateEventType::RoomMember => "@c:hs1" => 2, + ], + ] + .iter(), + ); + + assert_eq!(unconflicted, state_set![ + StateEventType::RoomMember => "@a:hs1" => 0, + ],); + assert_eq!(conflicted, state_set![ + StateEventType::RoomMember => "@b:hs1" => vec![1], + StateEventType::RoomMember => "@c:hs1" => vec![2], + ],); + } +} diff --git a/src/core/state_res/outcomes.txt b/src/core/state_res/outcomes.txt new file mode 100644 index 00000000..0fa1c734 --- /dev/null +++ b/src/core/state_res/outcomes.txt @@ -0,0 +1,104 @@ +11/29/2020 BRANCH: timo-spec-comp REV: d2a85669cc6056679ce6ca0fde4658a879ad2b08 +lexicographical topological sort + time: [1.7123 us 1.7157 us 1.7199 us] + change: [-1.7584% -1.5433% -1.3205%] (p = 0.00 < 0.05) + Performance has improved. +Found 8 outliers among 100 measurements (8.00%) + 2 (2.00%) low mild + 5 (5.00%) high mild + 1 (1.00%) high severe + +resolve state of 5 events one fork + time: [10.981 us 10.998 us 11.020 us] +Found 3 outliers among 100 measurements (3.00%) + 3 (3.00%) high mild + +resolve state of 10 events 3 conflicting + time: [26.858 us 26.946 us 27.037 us] + +11/29/2020 BRANCH: event-trait REV: f0eb1310efd49d722979f57f20bd1ac3592b0479 +lexicographical topological sort + time: [1.7686 us 1.7738 us 1.7810 us] + change: [-3.2752% -2.4634% -1.7635%] (p = 0.00 < 0.05) + Performance has improved. 
+Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high severe + +resolve state of 5 events one fork + time: [10.643 us 10.656 us 10.669 us] + change: [-4.9990% -3.8078% -2.8319%] (p = 0.00 < 0.05) + Performance has improved. +Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high severe + +resolve state of 10 events 3 conflicting + time: [29.149 us 29.252 us 29.375 us] + change: [-0.8433% -0.3270% +0.2656%] (p = 0.25 > 0.05) + No change in performance detected. +Found 1 outliers among 100 measurements (1.00%) + 1 (1.00%) high mild + +4/26/2020 BRANCH: fix-test-serde REV: +lexicographical topological sort + time: [1.6793 us 1.6823 us 1.6857 us] +Found 9 outliers among 100 measurements (9.00%) + 1 (1.00%) low mild + 4 (4.00%) high mild + 4 (4.00%) high severe + +resolve state of 5 events one fork + time: [9.9993 us 10.062 us 10.159 us] +Found 9 outliers among 100 measurements (9.00%) + 7 (7.00%) high mild + 2 (2.00%) high severe + +resolve state of 10 events 3 conflicting + time: [26.004 us 26.092 us 26.195 us] +Found 16 outliers among 100 measurements (16.00%) + 11 (11.00%) high mild + 5 (5.00%) high severe + +6/30/2021 BRANCH: state-closure REV: 174c3e2a72232ad75b3fb14b3551f5f746f4fe84 +lexicographical topological sort + time: [1.5496 us 1.5536 us 1.5586 us] +Found 9 outliers among 100 measurements (9.00%) + 1 (1.00%) low mild + 1 (1.00%) high mild + 7 (7.00%) high severe + +resolve state of 5 events one fork + time: [10.319 us 10.333 us 10.347 us] +Found 2 outliers among 100 measurements (2.00%) + 2 (2.00%) high severe + +resolve state of 10 events 3 conflicting + time: [25.770 us 25.805 us 25.839 us] +Found 7 outliers among 100 measurements (7.00%) + 5 (5.00%) high mild + 2 (2.00%) high severe + +7/20/2021 BRANCH stateres-result REV: +This marks the switch to HashSet/Map +lexicographical topological sort + time: [1.8122 us 1.8177 us 1.8233 us] + change: [+15.205% +15.919% +16.502%] (p = 0.00 < 0.05) + Performance has regressed. +Found 7 outliers among 100 measurements (7.00%) + 5 (5.00%) high mild + 2 (2.00%) high severe + +resolve state of 5 events one fork + time: [11.966 us 12.010 us 12.059 us] + change: [+16.089% +16.730% +17.469%] (p = 0.00 < 0.05) + Performance has regressed. +Found 7 outliers among 100 measurements (7.00%) + 3 (3.00%) high mild + 4 (4.00%) high severe + +resolve state of 10 events 3 conflicting + time: [29.092 us 29.201 us 29.311 us] + change: [+12.447% +12.847% +13.280%] (p = 0.00 < 0.05) + Performance has regressed. 
+Found 9 outliers among 100 measurements (9.00%) + 6 (6.00%) high mild + 3 (3.00%) high severe diff --git a/src/core/state_res/power_levels.rs b/src/core/state_res/power_levels.rs new file mode 100644 index 00000000..e1768574 --- /dev/null +++ b/src/core/state_res/power_levels.rs @@ -0,0 +1,256 @@ +use std::collections::BTreeMap; + +use ruma::{ + events::{room::power_levels::RoomPowerLevelsEventContent, TimelineEventType}, + power_levels::{default_power_level, NotificationPowerLevels}, + serde::{ + deserialize_v1_powerlevel, vec_deserialize_int_powerlevel_values, + vec_deserialize_v1_powerlevel_values, + }, + Int, OwnedUserId, UserId, +}; +use serde::Deserialize; +use serde_json::{from_str as from_json_str, Error}; +use tracing::error; + +use super::{Result, RoomVersion}; + +#[derive(Deserialize)] +struct IntRoomPowerLevelsEventContent { + #[serde(default = "default_power_level")] + ban: Int, + + #[serde(default)] + events: BTreeMap, + + #[serde(default)] + events_default: Int, + + #[serde(default)] + invite: Int, + + #[serde(default = "default_power_level")] + kick: Int, + + #[serde(default = "default_power_level")] + redact: Int, + + #[serde(default = "default_power_level")] + state_default: Int, + + #[serde(default)] + users: BTreeMap, + + #[serde(default)] + users_default: Int, + + #[serde(default)] + notifications: IntNotificationPowerLevels, +} + +impl From for RoomPowerLevelsEventContent { + fn from(int_pl: IntRoomPowerLevelsEventContent) -> Self { + let IntRoomPowerLevelsEventContent { + ban, + events, + events_default, + invite, + kick, + redact, + state_default, + users, + users_default, + notifications, + } = int_pl; + + let mut pl = Self::new(); + pl.ban = ban; + pl.events = events; + pl.events_default = events_default; + pl.invite = invite; + pl.kick = kick; + pl.redact = redact; + pl.state_default = state_default; + pl.users = users; + pl.users_default = users_default; + pl.notifications = notifications.into(); + + pl + } +} + +#[derive(Deserialize)] +struct IntNotificationPowerLevels { + #[serde(default = "default_power_level")] + room: Int, +} + +impl Default for IntNotificationPowerLevels { + fn default() -> Self { Self { room: default_power_level() } } +} + +impl From for NotificationPowerLevels { + fn from(int_notif: IntNotificationPowerLevels) -> Self { + let mut notif = Self::new(); + notif.room = int_notif.room; + + notif + } +} + +#[inline] +pub(crate) fn deserialize_power_levels( + content: &str, + room_version: &RoomVersion, +) -> Option { + if room_version.integer_power_levels { + deserialize_integer_power_levels(content) + } else { + deserialize_legacy_power_levels(content) + } +} + +fn deserialize_integer_power_levels(content: &str) -> Option { + match from_json_str::(content) { + | Ok(content) => Some(content.into()), + | Err(_) => { + error!("m.room.power_levels event is not valid with integer values"); + None + }, + } +} + +fn deserialize_legacy_power_levels(content: &str) -> Option { + match from_json_str(content) { + | Ok(content) => Some(content), + | Err(_) => { + error!( + "m.room.power_levels event is not valid with integer or string integer values" + ); + None + }, + } +} + +#[derive(Deserialize)] +pub(crate) struct PowerLevelsContentFields { + #[serde(default, deserialize_with = "vec_deserialize_v1_powerlevel_values")] + pub(crate) users: Vec<(OwnedUserId, Int)>, + + #[serde(default, deserialize_with = "deserialize_v1_powerlevel")] + pub(crate) users_default: Int, +} + +impl PowerLevelsContentFields { + pub(crate) fn get_user_power(&self, user_id: 
&UserId) -> Option<&Int> { + let comparator = |item: &(OwnedUserId, Int)| { + let item: &UserId = &item.0; + item.cmp(user_id) + }; + + self.users + .binary_search_by(comparator) + .ok() + .and_then(|idx| self.users.get(idx).map(|item| &item.1)) + } +} + +#[derive(Deserialize)] +struct IntPowerLevelsContentFields { + #[serde(default, deserialize_with = "vec_deserialize_int_powerlevel_values")] + users: Vec<(OwnedUserId, Int)>, + + #[serde(default)] + users_default: Int, +} + +impl From for PowerLevelsContentFields { + fn from(pl: IntPowerLevelsContentFields) -> Self { + let IntPowerLevelsContentFields { users, users_default } = pl; + Self { users, users_default } + } +} + +#[inline] +pub(crate) fn deserialize_power_levels_content_fields( + content: &str, + room_version: &RoomVersion, +) -> Result { + if room_version.integer_power_levels { + deserialize_integer_power_levels_content_fields(content) + } else { + deserialize_legacy_power_levels_content_fields(content) + } +} + +fn deserialize_integer_power_levels_content_fields( + content: &str, +) -> Result { + from_json_str::(content).map(Into::into) +} + +fn deserialize_legacy_power_levels_content_fields( + content: &str, +) -> Result { + from_json_str(content) +} + +#[derive(Deserialize)] +pub(crate) struct PowerLevelsContentInvite { + #[serde(default, deserialize_with = "deserialize_v1_powerlevel")] + pub(crate) invite: Int, +} + +#[derive(Deserialize)] +struct IntPowerLevelsContentInvite { + #[serde(default)] + invite: Int, +} + +impl From for PowerLevelsContentInvite { + fn from(pl: IntPowerLevelsContentInvite) -> Self { + let IntPowerLevelsContentInvite { invite } = pl; + Self { invite } + } +} + +pub(crate) fn deserialize_power_levels_content_invite( + content: &str, + room_version: &RoomVersion, +) -> Result { + if room_version.integer_power_levels { + from_json_str::(content).map(Into::into) + } else { + from_json_str(content) + } +} + +#[derive(Deserialize)] +pub(crate) struct PowerLevelsContentRedact { + #[serde(default = "default_power_level", deserialize_with = "deserialize_v1_powerlevel")] + pub(crate) redact: Int, +} + +#[derive(Deserialize)] +pub(crate) struct IntPowerLevelsContentRedact { + #[serde(default = "default_power_level")] + redact: Int, +} + +impl From for PowerLevelsContentRedact { + fn from(pl: IntPowerLevelsContentRedact) -> Self { + let IntPowerLevelsContentRedact { redact } = pl; + Self { redact } + } +} + +pub(crate) fn deserialize_power_levels_content_redact( + content: &str, + room_version: &RoomVersion, +) -> Result { + if room_version.integer_power_levels { + from_json_str::(content).map(Into::into) + } else { + from_json_str(content) + } +} diff --git a/src/core/state_res/room_version.rs b/src/core/state_res/room_version.rs new file mode 100644 index 00000000..e1b0afe1 --- /dev/null +++ b/src/core/state_res/room_version.rs @@ -0,0 +1,149 @@ +use ruma::RoomVersionId; + +use super::{Error, Result}; + +#[derive(Debug)] +#[allow(clippy::exhaustive_enums)] +pub enum RoomDisposition { + /// A room version that has a stable specification. + Stable, + /// A room version that is not yet fully specified. 
+ Unstable, +} + +#[derive(Debug)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub enum EventFormatVersion { + /// $id:server event id format + V1, + /// MSC1659-style $hash event id format: introduced for room v3 + V2, + /// MSC1884-style $hash format: introduced for room v4 + V3, +} + +#[derive(Debug)] +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub enum StateResolutionVersion { + /// State resolution for rooms at version 1. + V1, + /// State resolution for room at version 2 or later. + V2, +} + +#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] +pub struct RoomVersion { + /// The stability of this room. + pub disposition: RoomDisposition, + /// The format of the EventId. + pub event_format: EventFormatVersion, + /// Which state resolution algorithm is used. + pub state_res: StateResolutionVersion, + // FIXME: not sure what this one means? + pub enforce_key_validity: bool, + + /// `m.room.aliases` had special auth rules and redaction rules + /// before room version 6. + /// + /// before MSC2261/MSC2432, + pub special_case_aliases_auth: bool, + /// Strictly enforce canonical json, do not allow: + /// * Integers outside the range of [-2 ^ 53 + 1, 2 ^ 53 - 1] + /// * Floats + /// * NaN, Infinity, -Infinity + pub strict_canonicaljson: bool, + /// Verify notifications key while checking m.room.power_levels. + /// + /// bool: MSC2209: Check 'notifications' + pub limit_notifications_power_levels: bool, + /// Extra rules when verifying redaction events. + pub extra_redaction_checks: bool, + /// Allow knocking in event authentication. + /// + /// See [room v7 specification](https://spec.matrix.org/latest/rooms/v7/) for more information. + pub allow_knocking: bool, + /// Adds support for the restricted join rule. + /// + /// See: [MSC3289](https://github.com/matrix-org/matrix-spec-proposals/pull/3289) for more information. + pub restricted_join_rules: bool, + /// Adds support for the knock_restricted join rule. + /// + /// See: [MSC3787](https://github.com/matrix-org/matrix-spec-proposals/pull/3787) for more information. + pub knock_restricted_join_rule: bool, + /// Enforces integer power levels. + /// + /// See: [MSC3667](https://github.com/matrix-org/matrix-spec-proposals/pull/3667) for more information. + pub integer_power_levels: bool, + /// Determine the room creator using the `m.room.create` event's `sender`, + /// instead of the event content's `creator` field. + /// + /// See: [MSC2175](https://github.com/matrix-org/matrix-spec-proposals/pull/2175) for more information. 
+ pub use_room_create_sender: bool, +} + +impl RoomVersion { + pub const V1: Self = Self { + disposition: RoomDisposition::Stable, + event_format: EventFormatVersion::V1, + state_res: StateResolutionVersion::V1, + enforce_key_validity: false, + special_case_aliases_auth: true, + strict_canonicaljson: false, + limit_notifications_power_levels: false, + extra_redaction_checks: true, + allow_knocking: false, + restricted_join_rules: false, + knock_restricted_join_rule: false, + integer_power_levels: false, + use_room_create_sender: false, + }; + pub const V10: Self = Self { + knock_restricted_join_rule: true, + integer_power_levels: true, + ..Self::V9 + }; + pub const V11: Self = Self { + use_room_create_sender: true, + ..Self::V10 + }; + pub const V2: Self = Self { + state_res: StateResolutionVersion::V2, + ..Self::V1 + }; + pub const V3: Self = Self { + event_format: EventFormatVersion::V2, + extra_redaction_checks: false, + ..Self::V2 + }; + pub const V4: Self = Self { + event_format: EventFormatVersion::V3, + ..Self::V3 + }; + pub const V5: Self = Self { enforce_key_validity: true, ..Self::V4 }; + pub const V6: Self = Self { + special_case_aliases_auth: false, + strict_canonicaljson: true, + limit_notifications_power_levels: true, + ..Self::V5 + }; + pub const V7: Self = Self { allow_knocking: true, ..Self::V6 }; + pub const V8: Self = Self { restricted_join_rules: true, ..Self::V7 }; + pub const V9: Self = Self::V8; + + pub fn new(version: &RoomVersionId) -> Result { + Ok(match version { + | RoomVersionId::V1 => Self::V1, + | RoomVersionId::V2 => Self::V2, + | RoomVersionId::V3 => Self::V3, + | RoomVersionId::V4 => Self::V4, + | RoomVersionId::V5 => Self::V5, + | RoomVersionId::V6 => Self::V6, + | RoomVersionId::V7 => Self::V7, + | RoomVersionId::V8 => Self::V8, + | RoomVersionId::V9 => Self::V9, + | RoomVersionId::V10 => Self::V10, + | RoomVersionId::V11 => Self::V11, + | ver => return Err(Error::Unsupported(format!("found version `{ver}`"))), + }) + } +} diff --git a/src/core/state_res/state_event.rs b/src/core/state_res/state_event.rs new file mode 100644 index 00000000..2c038cfe --- /dev/null +++ b/src/core/state_res/state_event.rs @@ -0,0 +1,102 @@ +use std::{ + borrow::Borrow, + fmt::{Debug, Display}, + hash::Hash, + sync::Arc, +}; + +use ruma::{events::TimelineEventType, EventId, MilliSecondsSinceUnixEpoch, RoomId, UserId}; +use serde_json::value::RawValue as RawJsonValue; + +/// Abstraction of a PDU so users can have their own PDU types. +pub trait Event { + type Id: Clone + Debug + Display + Eq + Ord + Hash + Send + Borrow; + + /// The `EventId` of this event. + fn event_id(&self) -> &Self::Id; + + /// The `RoomId` of this event. + fn room_id(&self) -> &RoomId; + + /// The `UserId` of this event. + fn sender(&self) -> &UserId; + + /// The time of creation on the originating server. + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch; + + /// The event type. + fn event_type(&self) -> &TimelineEventType; + + /// The event's content. + fn content(&self) -> &RawJsonValue; + + /// The state key for this event. + fn state_key(&self) -> Option<&str>; + + /// The events before this event. + // Requires GATs to avoid boxing (and TAIT for making it convenient). + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_; + + /// All the authenticating events for this event. + // Requires GATs to avoid boxing (and TAIT for making it convenient). 
+ fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_; + + /// If this event is a redaction event this is the event it redacts. + fn redacts(&self) -> Option<&Self::Id>; +} + +impl Event for &T { + type Id = T::Id; + + fn event_id(&self) -> &Self::Id { (*self).event_id() } + + fn room_id(&self) -> &RoomId { (*self).room_id() } + + fn sender(&self) -> &UserId { (*self).sender() } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { (*self).origin_server_ts() } + + fn event_type(&self) -> &TimelineEventType { (*self).event_type() } + + fn content(&self) -> &RawJsonValue { (*self).content() } + + fn state_key(&self) -> Option<&str> { (*self).state_key() } + + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (*self).prev_events() + } + + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (*self).auth_events() + } + + fn redacts(&self) -> Option<&Self::Id> { (*self).redacts() } +} + +impl Event for Arc { + type Id = T::Id; + + fn event_id(&self) -> &Self::Id { (**self).event_id() } + + fn room_id(&self) -> &RoomId { (**self).room_id() } + + fn sender(&self) -> &UserId { (**self).sender() } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { (**self).origin_server_ts() } + + fn event_type(&self) -> &TimelineEventType { (**self).event_type() } + + fn content(&self) -> &RawJsonValue { (**self).content() } + + fn state_key(&self) -> Option<&str> { (**self).state_key() } + + fn prev_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (**self).prev_events() + } + + fn auth_events(&self) -> impl DoubleEndedIterator + Send + '_ { + (**self).auth_events() + } + + fn redacts(&self) -> Option<&Self::Id> { (**self).redacts() } +} diff --git a/src/core/state_res/state_res_bench.rs b/src/core/state_res/state_res_bench.rs new file mode 100644 index 00000000..a2bd2c23 --- /dev/null +++ b/src/core/state_res/state_res_bench.rs @@ -0,0 +1,648 @@ +// Because of criterion `cargo bench` works, +// but if you use `cargo bench -- --save-baseline ` +// or pass any other args to it, it fails with the error +// `cargo bench unknown option --save-baseline`. +// To pass args to criterion, use this form +// `cargo bench --bench -- --save-baseline `. + +#![allow(clippy::exhaustive_structs)] + +use std::{ + borrow::Borrow, + collections::{HashMap, HashSet}, + sync::{ + atomic::{AtomicU64, Ordering::SeqCst}, + Arc, + }, +}; + +use criterion::{criterion_group, criterion_main, Criterion}; +use event::PduEvent; +use futures::{future, future::ready}; +use ruma::{int, uint}; +use maplit::{btreemap, hashmap, hashset}; +use ruma::{ + room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, RoomVersionId, + Signatures, UserId, +}; +use ruma::events::{ + pdu::{EventHash, Pdu, RoomV3Pdu}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + StateEventType, TimelineEventType, +}; +use conduwuit::state_res::{self as state_res, Error, Event, Result, StateMap}; +use serde_json::{ + json, + value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue}, +}; + +static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); + +fn lexico_topo_sort(c: &mut Criterion) { + c.bench_function("lexicographical topological sort", |b| { + let graph = hashmap! 
{ + event_id("l") => hashset![event_id("o")], + event_id("m") => hashset![event_id("n"), event_id("o")], + event_id("n") => hashset![event_id("o")], + event_id("o") => hashset![], // "o" has zero outgoing edges but 4 incoming edges + event_id("p") => hashset![event_id("o")], + }; + b.iter(|| { + let _ = state_res::lexicographical_topological_sort(&graph, &|_| { + future::ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) + }); + }); + }); +} + +fn resolution_shallow_auth_chain(c: &mut Criterion) { + c.bench_function("resolve state of 5 events one fork", |b| { + let mut store = TestStore(hashmap! {}); + + // build up the DAG + let (state_at_bob, state_at_charlie, _) = store.set_up(); + + b.iter(|| async { + let ev_map = store.0.clone(); + let state_sets = [&state_at_bob, &state_at_charlie]; + let fetch = |id: OwnedEventId| ready(ev_map.get(&id).map(Arc::clone)); + let exists = |id: OwnedEventId| ready(ev_map.get(&id).is_some()); + let auth_chain_sets = state_sets + .iter() + .map(|map| { + store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() + }) + .collect(); + + let _ = match state_res::resolve( + &RoomVersionId::V6, + state_sets.into_iter(), + &auth_chain_sets, + &fetch, + &exists, + ) + .await + { + Ok(state) => state, + Err(e) => panic!("{e}"), + }; + }); + }); +} + +fn resolve_deeper_event_set(c: &mut Criterion) { + c.bench_function("resolve state of 10 events 3 conflicting", |b| { + let mut inner = INITIAL_EVENTS(); + let ban = BAN_STATE_SET(); + + inner.extend(ban); + let store = TestStore(inner.clone()); + + let state_set_a = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("MB")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id().to_owned()) + }) + .collect::>(); + + let state_set_b = [ + inner.get(&event_id("CREATE")).unwrap(), + inner.get(&event_id("IJR")).unwrap(), + inner.get(&event_id("IMA")).unwrap(), + inner.get(&event_id("IMB")).unwrap(), + inner.get(&event_id("IMC")).unwrap(), + inner.get(&event_id("IME")).unwrap(), + inner.get(&event_id("PA")).unwrap(), + ] + .iter() + .map(|ev| { + (ev.event_type().with_state_key(ev.state_key().unwrap()), ev.event_id().to_owned()) + }) + .collect::>(); + + b.iter(|| async { + let state_sets = [&state_set_a, &state_set_b]; + let auth_chain_sets = state_sets + .iter() + .map(|map| { + store.auth_event_ids(room_id(), map.values().cloned().collect()).unwrap() + }) + .collect(); + + let fetch = |id: OwnedEventId| ready(inner.get(&id).map(Arc::clone)); + let exists = |id: OwnedEventId| ready(inner.get(&id).is_some()); + let _ = match state_res::resolve( + &RoomVersionId::V6, + state_sets.into_iter(), + &auth_chain_sets, + &fetch, + &exists, + ) + .await + { + Ok(state) => state, + Err(_) => panic!("resolution failed during benchmarking"), + }; + }); + }); +} + +criterion_group!( + benches, + lexico_topo_sort, + resolution_shallow_auth_chain, + resolve_deeper_event_set +); + +criterion_main!(benches); + +//*///////////////////////////////////////////////////////////////////// +// +// IMPLEMENTATION DETAILS AHEAD +// +/////////////////////////////////////////////////////////////////////*/ +struct TestStore(HashMap>); + +#[allow(unused)] +impl TestStore { + fn get_event(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + self.0 + 
.get(event_id) + .map(Arc::clone) + .ok_or_else(|| Error::NotFound(format!("{} not found", event_id))) + } + + /// Returns the events that correspond to the `event_ids` sorted in the same order. + fn get_events(&self, room_id: &RoomId, event_ids: &[OwnedEventId]) -> Result>> { + let mut events = vec![]; + for id in event_ids { + events.push(self.get_event(room_id, id)?); + } + Ok(events) + } + + /// Returns a Vec of the related auth events to the given `event`. + fn auth_event_ids(&self, room_id: &RoomId, event_ids: Vec) -> Result> { + let mut result = HashSet::new(); + let mut stack = event_ids; + + // DFS for auth event chain + while !stack.is_empty() { + let ev_id = stack.pop().unwrap(); + if result.contains(&ev_id) { + continue; + } + + result.insert(ev_id.clone()); + + let event = self.get_event(room_id, ev_id.borrow())?; + + stack.extend(event.auth_events().map(ToOwned::to_owned)); + } + + Ok(result) + } + + /// Returns a vector representing the difference in auth chains of the given `events`. + fn auth_chain_diff(&self, room_id: &RoomId, event_ids: Vec>) -> Result> { + let mut auth_chain_sets = vec![]; + for ids in event_ids { + // TODO state store `auth_event_ids` returns self in the event ids list + // when an event returns `auth_event_ids` self is not contained + let chain = self.auth_event_ids(room_id, ids)?.into_iter().collect::>(); + auth_chain_sets.push(chain); + } + + if let Some(first) = auth_chain_sets.first().cloned() { + let common = auth_chain_sets + .iter() + .skip(1) + .fold(first, |a, b| a.intersection(b).cloned().collect::>()); + + Ok(auth_chain_sets + .into_iter() + .flatten() + .filter(|id| !common.contains(id.borrow())) + .collect()) + } else { + Ok(vec![]) + } + } +} + +impl TestStore { + #[allow(clippy::type_complexity)] + fn set_up( + &mut self, + ) -> (StateMap, StateMap, StateMap) { + let create_event = to_pdu_event::<&EventId>( + "CREATE", + alice(), + TimelineEventType::RoomCreate, + Some(""), + to_raw_json_value(&json!({ "creator": alice() })).unwrap(), + &[], + &[], + ); + let cre = create_event.event_id().to_owned(); + self.0.insert(cre.clone(), Arc::clone(&create_event)); + + let alice_mem = to_pdu_event( + "IMA", + alice(), + TimelineEventType::RoomMember, + Some(alice().to_string().as_str()), + member_content_join(), + &[cre.clone()], + &[cre.clone()], + ); + self.0.insert(alice_mem.event_id().to_owned(), Arc::clone(&alice_mem)); + + let join_rules = to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(), + &[cre.clone(), alice_mem.event_id().to_owned()], + &[alice_mem.event_id().to_owned()], + ); + self.0.insert(join_rules.event_id().to_owned(), join_rules.clone()); + + // Bob and Charlie join at the same time, so there is a fork + // this will be represented in the state_sets when we resolve + let bob_mem = to_pdu_event( + "IMB", + bob(), + TimelineEventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &[cre.clone(), join_rules.event_id().to_owned()], + &[join_rules.event_id().to_owned()], + ); + self.0.insert(bob_mem.event_id().to_owned(), bob_mem.clone()); + + let charlie_mem = to_pdu_event( + "IMC", + charlie(), + TimelineEventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &[cre, join_rules.event_id().to_owned()], + &[join_rules.event_id().to_owned()], + ); + self.0.insert(charlie_mem.event_id().to_owned(), charlie_mem.clone()); + + let state_at_bob = [&create_event, 
&alice_mem, &join_rules, &bob_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + (state_at_bob, state_at_charlie, expected) + } +} + +fn event_id(id: &str) -> OwnedEventId { + if id.contains('$') { + return id.try_into().unwrap(); + } + format!("${}:foo", id).try_into().unwrap() +} + +fn alice() -> &'static UserId { + user_id!("@alice:foo") +} + +fn bob() -> &'static UserId { + user_id!("@bob:foo") +} + +fn charlie() -> &'static UserId { + user_id!("@charlie:foo") +} + +fn ella() -> &'static UserId { + user_id!("@ella:foo") +} + +fn room_id() -> &'static RoomId { + room_id!("!test:foo") +} + +fn member_content_ban() -> Box { + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Ban)).unwrap() +} + +fn member_content_join() -> Box { + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap() +} + +fn to_pdu_event( + id: &str, + sender: &UserId, + ev_type: TimelineEventType, + state_key: Option<&str>, + content: Box, + auth_events: &[S], + prev_events: &[S], +) -> Arc +where + S: AsRef, +{ + // We don't care if the addition happens in order just that it is atomic + // (each event has its own value) + let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst); + let id = if id.contains('$') { id.to_owned() } else { format!("${}:foo", id) }; + let auth_events = auth_events.iter().map(AsRef::as_ref).map(event_id).collect::>(); + let prev_events = prev_events.iter().map(AsRef::as_ref).map(event_id).collect::>(); + + let state_key = state_key.map(ToOwned::to_owned); + Arc::new(PduEvent { + event_id: id.try_into().unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: btreemap! 
{}, + auth_events, + prev_events, + depth: uint!(0), + hashes: EventHash::new(String::new()), + signatures: Signatures::new(), + }), + }) +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn INITIAL_EVENTS() -> HashMap> { + vec![ + to_pdu_event::<&EventId>( + "CREATE", + alice(), + TimelineEventType::RoomCreate, + Some(""), + to_raw_json_value(&json!({ "creator": alice() })).unwrap(), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + TimelineEventType::RoomMember, + Some(alice().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100 } })).unwrap(), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + TimelineEventType::RoomMember, + Some(bob().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + TimelineEventType::RoomMember, + Some(charlie().to_string().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::<&EventId>( + "START", + charlie(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + &[], + &[], + ), + to_pdu_event::<&EventId>( + "END", + charlie(), + TimelineEventType::RoomTopic, + Some(""), + to_raw_json_value(&json!({})).unwrap(), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().to_owned(), ev)) + .collect() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +fn BAN_STATE_SET() -> HashMap> { + vec![ + to_pdu_event( + "PA", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + &["CREATE", "IMA", "IPOWER"], // auth_events + &["START"], // prev_events + ), + to_pdu_event( + "PB", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100, bob(): 50 } })).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["END"], + ), + to_pdu_event( + "MB", + alice(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + member_content_ban(), + &["CREATE", "IMA", "PB"], + &["PA"], + ), + to_pdu_event( + "IME", + ella(), + TimelineEventType::RoomMember, + Some(ella().as_str()), + member_content_join(), + &["CREATE", "IJR", "PA"], + &["MB"], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().to_owned(), ev)) + .collect() +} + +/// Convenience trait for adding event type plus state key to state maps. 
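+/// e.g. `(&TimelineEventType::RoomMember).with_state_key("@alice:foo")`
+/// yields `(StateEventType::RoomMember, "@alice:foo".to_owned())`.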
+trait EventTypeExt { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String); +} + +impl EventTypeExt for &TimelineEventType { + fn with_state_key(self, state_key: impl Into) -> (StateEventType, String) { + (self.to_string().into(), state_key.into()) + } +} + +mod event { + use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId}; + use ruma_events::{pdu::Pdu, TimelineEventType}; + use ruma_state_res::Event; + use serde::{Deserialize, Serialize}; + use serde_json::value::RawValue as RawJsonValue; + + impl Event for PduEvent { + type Id = OwnedEventId; + + fn event_id(&self) -> &Self::Id { + &self.event_id + } + + fn room_id(&self) -> &RoomId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.room_id, + Pdu::RoomV3Pdu(ev) => &ev.room_id, + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn sender(&self) -> &UserId { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.sender, + Pdu::RoomV3Pdu(ev) => &ev.sender, + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn event_type(&self) -> &TimelineEventType { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.kind, + Pdu::RoomV3Pdu(ev) => &ev.kind, + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn content(&self) -> &RawJsonValue { + match &self.rest { + Pdu::RoomV1Pdu(ev) => &ev.content, + Pdu::RoomV3Pdu(ev) => &ev.content, + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.origin_server_ts, + Pdu::RoomV3Pdu(ev) => ev.origin_server_ts, + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn state_key(&self) -> Option<&str> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.state_key.as_deref(), + Pdu::RoomV3Pdu(ev) => ev.state_key.as_deref(), + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn prev_events(&self) -> Box + Send + '_> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)), + Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()), + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn auth_events(&self) -> Box + Send + '_> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)), + Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()), + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + + fn redacts(&self) -> Option<&Self::Id> { + match &self.rest { + Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(), + Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(), + #[cfg(not(feature = "unstable-exhaustive-types"))] + _ => unreachable!("new PDU version"), + } + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + pub(crate) struct PduEvent { + pub(crate) event_id: OwnedEventId, + #[serde(flatten)] + pub(crate) rest: Pdu, + } +} diff --git a/src/core/state_res/test_utils.rs b/src/core/state_res/test_utils.rs new file mode 100644 index 00000000..7954b28d --- /dev/null +++ b/src/core/state_res/test_utils.rs @@ -0,0 +1,688 @@ +use std::{ + borrow::Borrow, + collections::{BTreeMap, HashMap, HashSet}, + sync::{ + atomic::{AtomicU64, Ordering::SeqCst}, + Arc, + }, +}; + +use futures_util::future::ready; +use 
js_int::{int, uint}; +use ruma_common::{ + event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, + RoomVersionId, ServerSignatures, UserId, +}; +use ruma_events::{ + pdu::{EventHash, Pdu, RoomV3Pdu}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + TimelineEventType, +}; +use serde_json::{ + json, + value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue}, +}; +use tracing::info; + +pub(crate) use self::event::PduEvent; +use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap}; + +static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); + +pub(crate) async fn do_check( + events: &[Arc], + edges: Vec>, + expected_state_ids: Vec, +) { + // To activate logging use `RUST_LOG=debug cargo t` + + let init_events = INITIAL_EVENTS(); + + let mut store = TestStore( + init_events + .values() + .chain(events) + .map(|ev| (ev.event_id().to_owned(), ev.clone())) + .collect(), + ); + + // This will be lexi_topo_sorted for resolution + let mut graph = HashMap::new(); + // This is the same as in `resolve` event_id -> OriginalStateEvent + let mut fake_event_map = HashMap::new(); + + // Create the DB of events that led up to this point + // TODO maybe clean up some of these clones it is just tests but... + for ev in init_events.values().chain(events) { + graph.insert(ev.event_id().to_owned(), HashSet::new()); + fake_event_map.insert(ev.event_id().to_owned(), ev.clone()); + } + + for pair in INITIAL_EDGES().windows(2) { + if let [a, b] = &pair { + graph + .entry(a.to_owned()) + .or_insert_with(HashSet::new) + .insert(b.clone()); + } + } + + for edge_list in edges { + for pair in edge_list.windows(2) { + if let [a, b] = &pair { + graph + .entry(a.to_owned()) + .or_insert_with(HashSet::new) + .insert(b.clone()); + } + } + } + + // event_id -> PduEvent + let mut event_map: HashMap> = HashMap::new(); + // event_id -> StateMap + let mut state_at_event: HashMap> = HashMap::new(); + + // Resolve the current state and add it to the state_at_event map then continue + // on in "time" + for node in crate::lexicographical_topological_sort(&graph, &|_id| async { + Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) + }) + .await + .unwrap() + { + let fake_event = fake_event_map.get(&node).unwrap(); + let event_id = fake_event.event_id().to_owned(); + + let prev_events = graph.get(&node).unwrap(); + + let state_before: StateMap = if prev_events.is_empty() { + HashMap::new() + } else if prev_events.len() == 1 { + state_at_event + .get(prev_events.iter().next().unwrap()) + .unwrap() + .clone() + } else { + let state_sets = prev_events + .iter() + .filter_map(|k| state_at_event.get(k)) + .collect::>(); + + info!( + "{:#?}", + state_sets + .iter() + .map(|map| map + .iter() + .map(|((ty, key), id)| format!("(({ty}{key:?}), {id})")) + .collect::>()) + .collect::>() + ); + + let auth_chain_sets: Vec<_> = state_sets + .iter() + .map(|map| { + store + .auth_event_ids(room_id(), map.values().cloned().collect()) + .unwrap() + }) + .collect(); + + let event_map = &event_map; + let fetch = |id: ::Id| ready(event_map.get(&id).cloned()); + let exists = |id: ::Id| ready(event_map.get(&id).is_some()); + let resolved = crate::resolve( + &RoomVersionId::V6, + state_sets, + &auth_chain_sets, + &fetch, + &exists, + 1, + ) + .await; + + match resolved { + | Ok(state) => state, + | Err(e) => panic!("resolution for {node} failed: {e}"), + } + }; + + let mut state_after = state_before.clone(); + + let ty = 
fake_event.event_type(); + let key = fake_event.state_key().unwrap(); + state_after.insert(ty.with_state_key(key), event_id.to_owned()); + + let auth_types = auth_types_for_event( + fake_event.event_type(), + fake_event.sender(), + fake_event.state_key(), + fake_event.content(), + ) + .unwrap(); + + let mut auth_events = vec![]; + for key in auth_types { + if state_before.contains_key(&key) { + auth_events.push(state_before[&key].clone()); + } + } + + // TODO The event is just remade, adding the auth_events and prev_events here + // the `to_pdu_event` was split into `init` and the fn below, could be better + let e = fake_event; + let ev_id = e.event_id(); + let event = to_pdu_event( + e.event_id().as_str(), + e.sender(), + e.event_type().clone(), + e.state_key(), + e.content().to_owned(), + &auth_events, + &prev_events.iter().cloned().collect::>(), + ); + + // We have to update our store, an actual user of this lib would + // be giving us state from a DB. + store.0.insert(ev_id.to_owned(), event.clone()); + + state_at_event.insert(node, state_after); + event_map.insert(event_id.to_owned(), Arc::clone(store.0.get(ev_id).unwrap())); + } + + let mut expected_state = StateMap::new(); + for node in expected_state_ids { + let ev = event_map.get(&node).unwrap_or_else(|| { + panic!( + "{node} not found in {:?}", + event_map + .keys() + .map(ToString::to_string) + .collect::>() + ) + }); + + let key = ev.event_type().with_state_key(ev.state_key().unwrap()); + + expected_state.insert(key, node); + } + + let start_state = state_at_event.get(event_id!("$START:foo")).unwrap(); + + let end_state = state_at_event + .get(event_id!("$END:foo")) + .unwrap() + .iter() + .filter(|(k, v)| { + expected_state.contains_key(k) + || start_state.get(k) != Some(*v) + // Filter out the dummy messages events. + // These act as points in time where there should be a known state to + // test against. + && **k != ("m.room.message".into(), "dummy".to_owned()) + }) + .map(|(k, v)| (k.clone(), v.clone())) + .collect::>(); + + assert_eq!(expected_state, end_state); +} + +#[allow(clippy::exhaustive_structs)] +pub(crate) struct TestStore(pub(crate) HashMap>); + +impl TestStore { + pub(crate) fn get_event(&self, _: &RoomId, event_id: &EventId) -> Result> { + self.0 + .get(event_id) + .cloned() + .ok_or_else(|| Error::NotFound(format!("{event_id} not found"))) + } + + /// Returns a Vec of the related auth events to the given `event`. 
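+	/// The chain is walked depth-first starting from `event_ids`; the result is
+	/// deduplicated into a `HashSet` and includes the starting ids themselves.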
+ pub(crate) fn auth_event_ids( + &self, + room_id: &RoomId, + event_ids: Vec, + ) -> Result> { + let mut result = HashSet::new(); + let mut stack = event_ids; + + // DFS for auth event chain + while let Some(ev_id) = stack.pop() { + if result.contains(&ev_id) { + continue; + } + + result.insert(ev_id.clone()); + + let event = self.get_event(room_id, ev_id.borrow())?; + + stack.extend(event.auth_events().map(ToOwned::to_owned)); + } + + Ok(result) + } +} + +// A StateStore implementation for testing +#[allow(clippy::type_complexity)] +impl TestStore { + pub(crate) fn set_up( + &mut self, + ) -> (StateMap, StateMap, StateMap) { + let create_event = to_pdu_event::<&EventId>( + "CREATE", + alice(), + TimelineEventType::RoomCreate, + Some(""), + to_raw_json_value(&json!({ "creator": alice() })).unwrap(), + &[], + &[], + ); + let cre = create_event.event_id().to_owned(); + self.0.insert(cre.clone(), Arc::clone(&create_event)); + + let alice_mem = to_pdu_event( + "IMA", + alice(), + TimelineEventType::RoomMember, + Some(alice().as_str()), + member_content_join(), + &[cre.clone()], + &[cre.clone()], + ); + self.0 + .insert(alice_mem.event_id().to_owned(), Arc::clone(&alice_mem)); + + let join_rules = to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(), + &[cre.clone(), alice_mem.event_id().to_owned()], + &[alice_mem.event_id().to_owned()], + ); + self.0 + .insert(join_rules.event_id().to_owned(), join_rules.clone()); + + // Bob and Charlie join at the same time, so there is a fork + // this will be represented in the state_sets when we resolve + let bob_mem = to_pdu_event( + "IMB", + bob(), + TimelineEventType::RoomMember, + Some(bob().as_str()), + member_content_join(), + &[cre.clone(), join_rules.event_id().to_owned()], + &[join_rules.event_id().to_owned()], + ); + self.0 + .insert(bob_mem.event_id().to_owned(), bob_mem.clone()); + + let charlie_mem = to_pdu_event( + "IMC", + charlie(), + TimelineEventType::RoomMember, + Some(charlie().as_str()), + member_content_join(), + &[cre, join_rules.event_id().to_owned()], + &[join_rules.event_id().to_owned()], + ); + self.0 + .insert(charlie_mem.event_id().to_owned(), charlie_mem.clone()); + + let state_at_bob = [&create_event, &alice_mem, &join_rules, &bob_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + let state_at_charlie = [&create_event, &alice_mem, &join_rules, &charlie_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + let expected = [&create_event, &alice_mem, &join_rules, &bob_mem, &charlie_mem] + .iter() + .map(|e| { + (e.event_type().with_state_key(e.state_key().unwrap()), e.event_id().to_owned()) + }) + .collect::>(); + + (state_at_bob, state_at_charlie, expected) + } +} + +pub(crate) fn event_id(id: &str) -> OwnedEventId { + if id.contains('$') { + return id.try_into().unwrap(); + } + + format!("${id}:foo").try_into().unwrap() +} + +pub(crate) fn alice() -> &'static UserId { user_id!("@alice:foo") } + +pub(crate) fn bob() -> &'static UserId { user_id!("@bob:foo") } + +pub(crate) fn charlie() -> &'static UserId { user_id!("@charlie:foo") } + +pub(crate) fn ella() -> &'static UserId { user_id!("@ella:foo") } + +pub(crate) fn zara() -> &'static UserId { user_id!("@zara:foo") } + +pub(crate) fn room_id() -> &'static RoomId { room_id!("!test:foo") } + +pub(crate) 
fn member_content_ban() -> Box { + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Ban)).unwrap() +} + +pub(crate) fn member_content_join() -> Box { + to_raw_json_value(&RoomMemberEventContent::new(MembershipState::Join)).unwrap() +} + +pub(crate) fn to_init_pdu_event( + id: &str, + sender: &UserId, + ev_type: TimelineEventType, + state_key: Option<&str>, + content: Box, +) -> Arc { + let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst); + let id = if id.contains('$') { + id.to_owned() + } else { + format!("${id}:foo") + }; + + let state_key = state_key.map(ToOwned::to_owned); + Arc::new(PduEvent { + event_id: id.try_into().unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: BTreeMap::new(), + auth_events: vec![], + prev_events: vec![], + depth: uint!(0), + hashes: EventHash::new("".to_owned()), + signatures: ServerSignatures::default(), + }), + }) +} + +pub(crate) fn to_pdu_event( + id: &str, + sender: &UserId, + ev_type: TimelineEventType, + state_key: Option<&str>, + content: Box, + auth_events: &[S], + prev_events: &[S], +) -> Arc +where + S: AsRef, +{ + let ts = SERVER_TIMESTAMP.fetch_add(1, SeqCst); + let id = if id.contains('$') { + id.to_owned() + } else { + format!("${id}:foo") + }; + let auth_events = auth_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + let prev_events = prev_events + .iter() + .map(AsRef::as_ref) + .map(event_id) + .collect::>(); + + let state_key = state_key.map(ToOwned::to_owned); + Arc::new(PduEvent { + event_id: id.try_into().unwrap(), + rest: Pdu::RoomV3Pdu(RoomV3Pdu { + room_id: room_id().to_owned(), + sender: sender.to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch(ts.try_into().unwrap()), + state_key, + kind: ev_type, + content, + redacts: None, + unsigned: BTreeMap::new(), + auth_events, + prev_events, + depth: uint!(0), + hashes: EventHash::new("".to_owned()), + signatures: ServerSignatures::default(), + }), + }) +} + +// all graphs start with these input events +#[allow(non_snake_case)] +pub(crate) fn INITIAL_EVENTS() -> HashMap> { + vec![ + to_pdu_event::<&EventId>( + "CREATE", + alice(), + TimelineEventType::RoomCreate, + Some(""), + to_raw_json_value(&json!({ "creator": alice() })).unwrap(), + &[], + &[], + ), + to_pdu_event( + "IMA", + alice(), + TimelineEventType::RoomMember, + Some(alice().as_str()), + member_content_join(), + &["CREATE"], + &["CREATE"], + ), + to_pdu_event( + "IPOWER", + alice(), + TimelineEventType::RoomPowerLevels, + Some(""), + to_raw_json_value(&json!({ "users": { alice(): 100 } })).unwrap(), + &["CREATE", "IMA"], + &["IMA"], + ), + to_pdu_event( + "IJR", + alice(), + TimelineEventType::RoomJoinRules, + Some(""), + to_raw_json_value(&RoomJoinRulesEventContent::new(JoinRule::Public)).unwrap(), + &["CREATE", "IMA", "IPOWER"], + &["IPOWER"], + ), + to_pdu_event( + "IMB", + bob(), + TimelineEventType::RoomMember, + Some(bob().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IJR"], + ), + to_pdu_event( + "IMC", + charlie(), + TimelineEventType::RoomMember, + Some(charlie().as_str()), + member_content_join(), + &["CREATE", "IJR", "IPOWER"], + &["IMB"], + ), + to_pdu_event::<&EventId>( + "START", + charlie(), + TimelineEventType::RoomMessage, + Some("dummy"), + to_raw_json_value(&json!({})).unwrap(), + &[], + &[], + ), + to_pdu_event::<&EventId>( + "END", + charlie(), 
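+			// Like "START" above, "END" is a dummy m.room.message state event that
+			// only marks a fixed point in time; do_check() excludes these dummies
+			// when comparing resolved state against the expected state.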
+ TimelineEventType::RoomMessage, + Some("dummy"), + to_raw_json_value(&json!({})).unwrap(), + &[], + &[], + ), + ] + .into_iter() + .map(|ev| (ev.event_id().to_owned(), ev)) + .collect() +} + +// all graphs start with these input events +#[allow(non_snake_case)] +pub(crate) fn INITIAL_EVENTS_CREATE_ROOM() -> HashMap> { + vec![to_pdu_event::<&EventId>( + "CREATE", + alice(), + TimelineEventType::RoomCreate, + Some(""), + to_raw_json_value(&json!({ "creator": alice() })).unwrap(), + &[], + &[], + )] + .into_iter() + .map(|ev| (ev.event_id().to_owned(), ev)) + .collect() +} + +#[allow(non_snake_case)] +pub(crate) fn INITIAL_EDGES() -> Vec { + vec!["START", "IMC", "IMB", "IJR", "IPOWER", "IMA", "CREATE"] + .into_iter() + .map(event_id) + .collect::>() +} + +pub(crate) mod event { + use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId}; + use ruma_events::{pdu::Pdu, TimelineEventType}; + use serde::{Deserialize, Serialize}; + use serde_json::value::RawValue as RawJsonValue; + + use crate::Event; + + impl Event for PduEvent { + type Id = OwnedEventId; + + fn event_id(&self) -> &Self::Id { &self.event_id } + + fn room_id(&self) -> &RoomId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.room_id, + | Pdu::RoomV3Pdu(ev) => &ev.room_id, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn sender(&self) -> &UserId { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.sender, + | Pdu::RoomV3Pdu(ev) => &ev.sender, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn event_type(&self) -> &TimelineEventType { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.kind, + | Pdu::RoomV3Pdu(ev) => &ev.kind, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn content(&self) -> &RawJsonValue { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => &ev.content, + | Pdu::RoomV3Pdu(ev) => &ev.content, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.origin_server_ts, + | Pdu::RoomV3Pdu(ev) => ev.origin_server_ts, + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn state_key(&self) -> Option<&str> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.state_key.as_deref(), + | Pdu::RoomV3Pdu(ev) => ev.state_key.as_deref(), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + #[allow(refining_impl_trait)] + fn prev_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => Box::new(ev.prev_events.iter().map(|(id, _)| id)), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.prev_events.iter()), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + #[allow(refining_impl_trait)] + fn auth_events(&self) -> Box + Send + '_> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => Box::new(ev.auth_events.iter().map(|(id, _)| id)), + | Pdu::RoomV3Pdu(ev) => Box::new(ev.auth_events.iter()), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + + fn redacts(&self) -> Option<&Self::Id> { + match &self.rest { + | Pdu::RoomV1Pdu(ev) => ev.redacts.as_ref(), + | Pdu::RoomV3Pdu(ev) => ev.redacts.as_ref(), + #[allow(unreachable_patterns)] + | _ => unreachable!("new PDU version"), + } + } + } + + #[derive(Clone, Debug, Deserialize, Serialize)] + #[allow(clippy::exhaustive_structs)] + pub(crate) struct PduEvent { + pub(crate) 
event_id: OwnedEventId, + #[serde(flatten)] + pub(crate) rest: Pdu, + } +} diff --git a/src/service/rooms/event_handler/fetch_prev.rs b/src/service/rooms/event_handler/fetch_prev.rs index aea70739..5a38f7fe 100644 --- a/src/service/rooms/event_handler/fetch_prev.rs +++ b/src/service/rooms/event_handler/fetch_prev.rs @@ -3,12 +3,15 @@ use std::{ sync::Arc, }; -use conduwuit::{debug_warn, err, implement, PduEvent, Result}; +use conduwuit::{ + debug_warn, err, implement, + state_res::{self}, + PduEvent, Result, +}; use futures::{future, FutureExt}; use ruma::{ - int, - state_res::{self}, - uint, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, UInt, + int, uint, CanonicalJsonValue, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, ServerName, + UInt, }; use super::check_room_id; diff --git a/src/service/rooms/event_handler/handle_outlier_pdu.rs b/src/service/rooms/event_handler/handle_outlier_pdu.rs index b7c38313..3cc15fc4 100644 --- a/src/service/rooms/event_handler/handle_outlier_pdu.rs +++ b/src/service/rooms/event_handler/handle_outlier_pdu.rs @@ -3,10 +3,12 @@ use std::{ sync::Arc, }; -use conduwuit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; +use conduwuit::{ + debug, debug_info, err, implement, state_res, trace, warn, Err, Error, PduEvent, Result, +}; use futures::{future::ready, TryFutureExt}; use ruma::{ - api::client::error::ErrorKind, events::StateEventType, state_res, CanonicalJsonObject, + api::client::error::ErrorKind, events::StateEventType, CanonicalJsonObject, CanonicalJsonValue, EventId, RoomId, ServerName, }; diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 8bcbc48b..5960c734 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -19,12 +19,12 @@ use std::{ use conduwuit::{ utils::{MutexMap, TryFutureExtExt}, - Err, PduEvent, Result, Server, + Err, PduEvent, Result, RoomVersion, Server, }; use futures::TryFutureExt; use ruma::{ - events::room::create::RoomCreateEventContent, state_res::RoomVersion, OwnedEventId, - OwnedRoomId, RoomId, RoomVersionId, + events::room::create::RoomCreateEventContent, OwnedEventId, OwnedRoomId, RoomId, + RoomVersionId, }; use crate::{globals, rooms, sending, server_keys, Dep}; diff --git a/src/service/rooms/event_handler/resolve_state.rs b/src/service/rooms/event_handler/resolve_state.rs index eb9ca01f..28011a1b 100644 --- a/src/service/rooms/event_handler/resolve_state.rs +++ b/src/service/rooms/event_handler/resolve_state.rs @@ -5,15 +5,14 @@ use std::{ }; use conduwuit::{ - err, implement, trace, + err, implement, + state_res::{self, StateMap}, + trace, utils::stream::{automatic_width, IterStream, ReadyExt, TryWidebandExt, WidebandExt}, Error, Result, }; use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use ruma::{ - state_res::{self, StateMap}, - OwnedEventId, RoomId, RoomVersionId, -}; +use ruma::{OwnedEventId, RoomId, RoomVersionId}; use crate::rooms::state_compressor::CompressedState; diff --git a/src/service/rooms/event_handler/state_at_incoming.rs b/src/service/rooms/event_handler/state_at_incoming.rs index 7bf3b8f8..843b2af9 100644 --- a/src/service/rooms/event_handler/state_at_incoming.rs +++ b/src/service/rooms/event_handler/state_at_incoming.rs @@ -8,10 +8,10 @@ use std::{ use conduwuit::{ debug, err, implement, trace, utils::stream::{BroadbandExt, IterStream, ReadyExt, TryBroadbandExt, TryWidebandExt}, - PduEvent, Result, + PduEvent, 
Result, StateMap, }; use futures::{future::try_join, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; -use ruma::{state_res::StateMap, OwnedEventId, RoomId, RoomVersionId}; +use ruma::{OwnedEventId, RoomId, RoomVersionId}; use crate::rooms::short::ShortStateHash; diff --git a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs index b33b0388..f319ba48 100644 --- a/src/service/rooms/event_handler/upgrade_outlier_pdu.rs +++ b/src/service/rooms/event_handler/upgrade_outlier_pdu.rs @@ -1,16 +1,12 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::Instant}; use conduwuit::{ - debug, debug_info, err, implement, trace, + debug, debug_info, err, implement, state_res, trace, utils::stream::{BroadbandExt, ReadyExt}, - warn, Err, PduEvent, Result, + warn, Err, EventTypeExt, PduEvent, Result, }; use futures::{future::ready, FutureExt, StreamExt}; -use ruma::{ - events::StateEventType, - state_res::{self, EventTypeExt}, - CanonicalJsonValue, RoomId, ServerName, -}; +use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName}; use super::{get_room_version_id, to_room_version}; use crate::rooms::{ diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index de90a89c..d538de3c 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -3,6 +3,7 @@ use std::{collections::HashMap, fmt::Write, iter::once, sync::Arc}; use conduwuit::{ err, result::FlatOk, + state_res::{self, StateMap}, utils::{ calculate_hash, stream::{BroadbandExt, TryIgnore}, @@ -20,7 +21,6 @@ use ruma::{ AnyStrippedStateEvent, StateEventType, TimelineEventType, }, serde::Raw, - state_res::{self, StateMap}, EventId, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId, UserId, }; diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index a7edd4a4..d6154121 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -12,6 +12,7 @@ use std::{ use conduwuit::{ at, debug, debug_warn, err, error, implement, info, pdu::{gen_event_id, EventHash, PduBuilder, PduCount, PduEvent}, + state_res::{self, Event, RoomVersion}, utils::{ self, future::TryExtExt, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt, }, @@ -36,7 +37,6 @@ use ruma::{ GlobalAccountDataEventType, StateEventType, TimelineEventType, }, push::{Action, Ruleset, Tweak}, - state_res::{self, Event, RoomVersion}, uint, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, RoomVersionId, ServerName, UserId, };