optimize further into state-res with SmallString

triage and de-lints for state-res.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-02-08 01:58:13 +00:00 committed by strawberry
parent 0a9a9b3c92
commit f2ca670c3b
15 changed files with 192 additions and 145 deletions

1
Cargo.lock generated
View file

@ -810,6 +810,7 @@ dependencies = [
"libc",
"libloading",
"log",
"maplit",
"nix",
"num-traits",
"rand",

View file

@ -379,6 +379,7 @@ features = [
"unstable-msc4203", # sending to-device events to appservices
"unstable-msc4210", # remove legacy mentions
"unstable-extensible-events",
"unstable-pdu",
]
[workspace.dependencies.rust-rocksdb]
@ -527,6 +528,9 @@ features = ["std"]
version = "0.3.2"
features = ["std"]
[workspace.dependencies.maplit]
version = "1.0.2"
#
# Patches
#

View file

@ -14,7 +14,7 @@ use conduwuit::{
result::FlatOk,
state_res, trace,
utils::{self, shuffle, IterStream, ReadyExt},
warn, Err, PduEvent, Result,
warn, Err, PduEvent, Result, StateKey,
};
use futures::{join, FutureExt, StreamExt, TryFutureExt};
use ruma::{
@ -1151,8 +1151,8 @@ async fn join_room_by_id_helper_remote(
debug!("Running send_join auth check");
let fetch_state = &state;
let state_fetch = |k: &'static StateEventType, s: String| async move {
let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?;
let state_fetch = |k: StateEventType, s: StateKey| async move {
let shortstatekey = services.rooms.short.get_shortstatekey(&k, &s).await.ok()?;
let event_id = fetch_state.get(&shortstatekey)?;
services.rooms.timeline.get_pdu(event_id).await.ok()
@ -1162,7 +1162,7 @@ async fn join_room_by_id_helper_remote(
&state_res::RoomVersion::new(&room_version_id)?,
&parsed_join_pdu,
None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()),
|k, s| state_fetch(k.clone(), s.into()),
)
.await
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;

View file

@ -395,9 +395,12 @@ pub(crate) async fn sync_events_v4_route(
.map_or(10, usize_from_u64_truncated)
.min(100);
todo_room
.0
.extend(list.room_details.required_state.iter().cloned());
todo_room.0.extend(
list.room_details
.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date
@ -449,7 +452,11 @@ pub(crate) async fn sync_events_v4_route(
.map_or(10, usize_from_u64_truncated)
.min(100);
todo_room.0.extend(room.required_state.iter().cloned());
todo_room.0.extend(
room.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min(

View file

@ -223,7 +223,11 @@ async fn fetch_subscriptions(
let limit: UInt = room.timeline_limit;
todo_room.0.extend(room.required_state.iter().cloned());
todo_room.0.extend(
room.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(usize_from_ruma(limit));
// 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min(
@ -303,9 +307,12 @@ async fn handle_lists<'a>(
let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100);
todo_room
.0
.extend(list.room_details.required_state.iter().cloned());
todo_room.0.extend(
list.room_details
.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date

View file

@ -116,5 +116,8 @@ nix.workspace = true
hardened_malloc-rs.workspace = true
hardened_malloc-rs.optional = true
[dev-dependencies]
maplit.workspace = true
[lints]
workspace = true

View file

@ -21,7 +21,6 @@ use serde::{
Deserialize,
};
use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};
use tracing::{debug, error, instrument, trace, warn};
use super::{
power_levels::{
@ -29,8 +28,9 @@ use super::{
deserialize_power_levels_content_invite, deserialize_power_levels_content_redact,
},
room_version::RoomVersion,
Error, Event, Result, StateEventType, TimelineEventType,
Error, Event, Result, StateEventType, StateKey, TimelineEventType,
};
use crate::{debug, error, trace, warn};
// FIXME: field extracting could be bundled for `content`
#[derive(Deserialize)]
@ -56,15 +56,15 @@ pub fn auth_types_for_event(
sender: &UserId,
state_key: Option<&str>,
content: &RawJsonValue,
) -> serde_json::Result<Vec<(StateEventType, String)>> {
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
if kind == &TimelineEventType::RoomCreate {
return Ok(vec![]);
}
let mut auth_types = vec![
(StateEventType::RoomPowerLevels, String::new()),
(StateEventType::RoomMember, sender.to_string()),
(StateEventType::RoomCreate, String::new()),
(StateEventType::RoomPowerLevels, StateKey::new()),
(StateEventType::RoomMember, sender.as_str().into()),
(StateEventType::RoomCreate, StateKey::new()),
];
if kind == &TimelineEventType::RoomMember {
@ -82,7 +82,7 @@ pub fn auth_types_for_event(
if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock]
.contains(&membership)
{
let key = (StateEventType::RoomJoinRules, String::new());
let key = (StateEventType::RoomJoinRules, StateKey::new());
if !auth_types.contains(&key) {
auth_types.push(key);
}
@ -91,21 +91,22 @@ pub fn auth_types_for_event(
.join_authorised_via_users_server
.map(|m| m.deserialize())
{
let key = (StateEventType::RoomMember, u.to_string());
let key = (StateEventType::RoomMember, u.as_str().into());
if !auth_types.contains(&key) {
auth_types.push(key);
}
}
}
let key = (StateEventType::RoomMember, state_key.to_owned());
let key = (StateEventType::RoomMember, state_key.into());
if !auth_types.contains(&key) {
auth_types.push(key);
}
if membership == MembershipState::Invite {
if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) {
let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token);
let key =
(StateEventType::RoomThirdPartyInvite, t_id.signed.token.into());
if !auth_types.contains(&key) {
auth_types.push(key);
}
@ -128,7 +129,13 @@ pub fn auth_types_for_event(
/// The `fetch_state` closure should gather state from a state snapshot. We need
/// to know if the event passes auth against some state not a recursive
/// collection of auth_events fields.
#[instrument(level = "debug", skip_all, fields(event_id = incoming_event.event_id().borrow().as_str()))]
#[tracing::instrument(
level = "debug",
skip_all,
fields(
event_id = incoming_event.event_id().borrow().as_str()
)
)]
pub async fn auth_check<F, Fut, Fetched, Incoming>(
room_version: &RoomVersion,
incoming_event: &Incoming,
@ -136,10 +143,10 @@ pub async fn auth_check<F, Fut, Fetched, Incoming>(
fetch_state: F,
) -> Result<bool, Error>
where
F: Fn(&'static StateEventType, &str) -> Fut,
F: Fn(&StateEventType, &str) -> Fut + Send,
Fut: Future<Output = Option<Fetched>> + Send,
Fetched: Event + Send,
Incoming: Event + Send,
Incoming: Event + Send + Sync,
{
debug!(
"auth_check beginning for {} ({})",
@ -262,6 +269,7 @@ where
// sender domain of the event does not match the sender domain of the create
// event, reject.
#[derive(Deserialize)]
#[allow(clippy::items_after_statements)]
struct RoomCreateContentFederate {
#[serde(rename = "m.federate", default = "ruma::serde::default_true")]
federate: bool,
@ -354,7 +362,7 @@ where
join_rules_event.as_ref(),
user_for_join_auth.as_deref(),
&user_for_join_auth_membership,
room_create_event,
&room_create_event,
)? {
return Ok(false);
}
@ -364,6 +372,7 @@ where
}
// If the sender's current membership state is not join, reject
#[allow(clippy::manual_let_else)]
let sender_member_event = match sender_member_event {
| Some(mem) => mem,
| None => {
@ -498,19 +507,20 @@ where
/// This is generated by calling `auth_types_for_event` with the membership
/// event and the current State.
#[allow(clippy::too_many_arguments)]
#[allow(clippy::cognitive_complexity)]
fn valid_membership_change(
room_version: &RoomVersion,
target_user: &UserId,
target_user_membership_event: Option<impl Event>,
target_user_membership_event: Option<&impl Event>,
sender: &UserId,
sender_membership_event: Option<impl Event>,
sender_membership_event: Option<&impl Event>,
current_event: impl Event,
current_third_party_invite: Option<impl Event>,
power_levels_event: Option<impl Event>,
join_rules_event: Option<impl Event>,
current_third_party_invite: Option<&impl Event>,
power_levels_event: Option<&impl Event>,
join_rules_event: Option<&impl Event>,
user_for_join_auth: Option<&UserId>,
user_for_join_auth_membership: &MembershipState,
create_room: impl Event,
create_room: &impl Event,
) -> Result<bool> {
#[derive(Deserialize)]
struct GetThirdPartyInvite {
@ -856,6 +866,7 @@ fn check_power_levels(
// and integers here
debug!("validation of power event finished");
#[allow(clippy::manual_let_else)]
let current_state = match previous_power_event {
| Some(current_state) => current_state,
// If there is no previous m.room.power_levels event in the room, allow
@ -1054,6 +1065,7 @@ fn verify_third_party_invite(
// If there is no m.room.third_party_invite event in the current room state with
// state_key matching token, reject
#[allow(clippy::manual_let_else)]
let current_tpid = match current_third_party_invite {
| Some(id) => id,
| None => return false,
@ -1069,12 +1081,14 @@ fn verify_third_party_invite(
// If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow
#[allow(clippy::manual_let_else)]
let tpid_ev =
match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) {
| Ok(ev) => ev,
| Err(_) => return false,
};
#[allow(clippy::manual_let_else)]
let decoded_invite_token = match Base64::parse(&tp_id.signed.token) {
| Ok(tok) => tok,
// FIXME: Log a warning?
@ -1096,7 +1110,7 @@ fn verify_third_party_invite(
mod tests {
use std::sync::Arc;
use ruma_events::{
use ruma::events::{
room::{
join_rules::{
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
@ -1107,7 +1121,7 @@ mod tests {
};
use serde_json::value::to_raw_value as to_raw_json_value;
use crate::{
use crate::state_res::{
event_auth::valid_membership_change,
test_utils::{
alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id,
@ -1145,16 +1159,16 @@ mod tests {
assert!(valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}
@ -1188,16 +1202,16 @@ mod tests {
assert!(!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}
@ -1231,16 +1245,16 @@ mod tests {
assert!(valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}
@ -1274,16 +1288,16 @@ mod tests {
assert!(!valid_membership_change(
&RoomVersion::V6,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}
@ -1334,32 +1348,32 @@ mod tests {
assert!(valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(alice()),
&MembershipState::Join,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
assert!(!valid_membership_change(
&RoomVersion::V9,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(ella()),
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}
@ -1402,16 +1416,16 @@ mod tests {
assert!(valid_membership_change(
&RoomVersion::V7,
target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()),
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender,
fetch_state(StateEventType::RoomMember, sender.to_string()),
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester,
None::<PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None,
&MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
)
.unwrap());
}

View file

@ -1,3 +1,5 @@
#![cfg_attr(test, allow(warnings))]
pub(crate) mod error;
pub mod event_auth;
mod power_levels;
@ -12,7 +14,7 @@ use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
fmt::Debug,
hash::Hash,
hash::{BuildHasher, Hash},
};
use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
@ -32,13 +34,13 @@ pub use self::{
room_version::RoomVersion,
state_event::Event,
};
use crate::{debug, trace, warn};
use crate::{debug, pdu::StateKey, trace, warn};
/// A mapping of event type and state_key to some value `T`, usually an
/// `EventId`.
pub type StateMap<T> = HashMap<TypeStateKey, T>;
pub type StateMapItem<T> = (TypeStateKey, T);
pub type TypeStateKey = (StateEventType, String);
pub type TypeStateKey = (StateEventType, StateKey);
type Result<T, E = Error> = crate::Result<T, E>;
@ -68,10 +70,10 @@ type Result<T, E = Error> = crate::Result<T, E>;
/// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
//#[tracing::instrument(level event_fetch))]
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
auth_chain_sets: &'a [HashSet<E::Id>],
state_sets: Sets,
auth_chain_sets: &'a [HashSet<E::Id, Hasher>],
event_fetch: &Fetch,
event_exists: &Exists,
parallel_fetches: usize,
@ -81,7 +83,9 @@ where
FetchFut: Future<Output = Option<E>> + Send,
Exists: Fn(E::Id) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
E: Event + Clone + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
for<'b> &'b E: Send,
@ -178,7 +182,7 @@ where
trace!(list = ?events_to_resolve, "events left to resolve");
// This "epochs" power level event
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new()));
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new()));
debug!(event_id = ?power_event, "power event");
@ -222,16 +226,17 @@ fn separate<'a, Id>(
where
Id: Clone + Eq + Hash + 'a,
{
let mut state_set_count = 0_usize;
let mut state_set_count: usize = 0;
let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
let state_sets_iter =
state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1));
for (k, v) in state_sets_iter.flatten() {
occurrences
.entry(k)
.or_default()
.entry(v)
.and_modify(|x| *x += 1)
.and_modify(|x: &mut usize| *x = x.saturating_add(1))
.or_insert(1);
}
@ -246,7 +251,7 @@ where
conflicted_state
.entry((k.0.clone(), k.1.clone()))
.and_modify(|x: &mut Vec<_>| x.push(id.clone()))
.or_insert(vec![id.clone()]);
.or_insert_with(|| vec![id.clone()]);
}
}
}
@ -255,9 +260,13 @@ where
}
/// Returns a Vec of deduped EventIds that appear in some chains but not others.
fn get_auth_chain_diff<Id>(auth_chain_sets: &[HashSet<Id>]) -> impl Iterator<Item = Id> + Send
#[allow(clippy::arithmetic_side_effects)]
fn get_auth_chain_diff<Id, Hasher>(
auth_chain_sets: &[HashSet<Id, Hasher>],
) -> impl Iterator<Item = Id> + Send
where
Id: Clone + Eq + Hash + Send,
Hasher: BuildHasher + Send + Sync,
{
let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Id, usize> = HashMap::new();
@ -288,7 +297,7 @@ async fn reverse_topological_power_sort<E, F, Fut>(
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
{
debug!("reverse topological sort of power events");
@ -337,14 +346,15 @@ where
/// `key_fn` is used as to obtain the power level and age of an event for
/// breaking ties (together with the event ID).
#[tracing::instrument(level = "debug", skip_all)]
pub async fn lexicographical_topological_sort<Id, F, Fut>(
graph: &HashMap<Id, HashSet<Id>>,
pub async fn lexicographical_topological_sort<Id, F, Fut, Hasher>(
graph: &HashMap<Id, HashSet<Id, Hasher>>,
key_fn: &F,
) -> Result<Vec<Id>>
where
F: Fn(Id) -> Fut + Sync,
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send + Sync,
Hasher: BuildHasher + Default + Clone + Send + Sync,
{
#[derive(PartialEq, Eq)]
struct TieBreaker<'a, Id> {
@ -395,7 +405,7 @@ where
// The number of events that depend on the given event (the EventId key)
// How many events reference this event in the DAG as a parent
let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
let mut reverse_graph: HashMap<_, HashSet<_, Hasher>> = HashMap::new();
// Vec of nodes that have zero out degree, least recent events.
let mut zero_outdegree = Vec::new();
@ -727,8 +737,8 @@ async fn get_mainline_depth<E, F, Fut>(
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
{
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
@ -758,10 +768,10 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
auth_diff: &HashSet<E::Id>,
fetch_event: &F,
) where
F: Fn(E::Id) -> Fut,
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Clone + Send + Sync,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
@ -788,7 +798,7 @@ where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
E::Id: Borrow<EventId> + Send + Sync,
{
match fetch(event_id.clone()).await.as_ref() {
| Some(state) => is_power_event(state),
@ -820,18 +830,18 @@ fn is_power_event(event: impl Event) -> bool {
/// Convenience trait for adding event type plus state key to state maps.
pub trait EventTypeExt {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey);
}
impl EventTypeExt for StateEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self, state_key.into())
}
}
impl EventTypeExt for TimelineEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
(self.to_string().into(), state_key.into())
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self.into(), state_key.into())
}
}
@ -839,7 +849,7 @@ impl<T> EventTypeExt for &T
where
T: EventTypeExt + Clone,
{
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
self.to_owned().with_state_key(state_key)
}
}
@ -858,13 +868,11 @@ mod tests {
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
StateEventType, TimelineEventType,
},
int, uint,
int, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
};
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use tracing::debug;
use crate::{
use super::{
is_power_event,
room_version::RoomVersion,
test_utils::{
@ -874,6 +882,7 @@ mod tests {
},
Event, EventTypeExt, StateMap,
};
use crate::debug;
async fn test_event_sort() {
use futures::future::ready;
@ -898,11 +907,11 @@ mod tests {
let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
.await
.unwrap();
let resolved_power = crate::iterative_auth_check(
let resolved_power = super::iterative_auth_check(
&RoomVersion::V6,
sorted_power_events.iter(),
HashMap::new(), // unconflicted events
@ -918,10 +927,10 @@ mod tests {
events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".to_owned()))
.get(&(StateEventType::RoomPowerLevels, "".into()))
.cloned();
let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
.await
.unwrap();
@ -1302,7 +1311,7 @@ mod tests {
})
.collect();
let resolved = match crate::resolve(
let resolved = match super::resolve(
&RoomVersionId::V2,
&state_sets,
&auth_chain,
@ -1333,7 +1342,7 @@ mod tests {
event_id("p") => hashset![event_id("o")],
};
let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
let res = super::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
@ -1421,7 +1430,7 @@ mod tests {
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
let resolved = match crate::resolve(
let resolved = match super::resolve(
&RoomVersionId::V6,
&state_sets,
&auth_chain,
@ -1552,7 +1561,7 @@ mod tests {
#[allow(unused_mut)]
let mut x = StateMap::new();
$(
x.insert(($kind, $key.to_owned()), $id);
x.insert(($kind, $key.into()), $id);
)*
x
}};

View file

@ -32,6 +32,7 @@ pub enum StateResolutionVersion {
}
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[allow(clippy::struct_excessive_bools)]
pub struct RoomVersion {
/// The stability of this room.
pub disposition: RoomDisposition,

View file

@ -7,28 +7,28 @@ use std::{
},
};
use futures_util::future::ready;
use js_int::{int, uint};
use ruma_common::{
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
RoomVersionId, ServerSignatures, UserId,
};
use ruma_events::{
use futures::future::ready;
use ruma::{
event_id,
events::{
pdu::{EventHash, Pdu, RoomV3Pdu},
room::{
join_rules::{JoinRule, RoomJoinRulesEventContent},
member::{MembershipState, RoomMemberEventContent},
},
TimelineEventType,
},
int, room_id, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
RoomVersionId, ServerSignatures, UserId,
};
use serde_json::{
json,
value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue},
};
use tracing::info;
pub(crate) use self::event::PduEvent;
use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap};
use super::auth_types_for_event;
use crate::{info, Event, EventTypeExt, Result, StateMap};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
@ -88,7 +88,7 @@ pub(crate) async fn do_check(
// Resolve the current state and add it to the state_at_event map then continue
// on in "time"
for node in crate::lexicographical_topological_sort(&graph, &|_id| async {
for node in super::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
@ -135,7 +135,7 @@ pub(crate) async fn do_check(
let event_map = &event_map;
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
let resolved = crate::resolve(
let resolved = super::resolve(
&RoomVersionId::V6,
state_sets,
&auth_chain_sets,
@ -223,7 +223,7 @@ pub(crate) async fn do_check(
// Filter out the dummy messages events.
// These act as points in time where there should be a known state to
// test against.
&& **k != ("m.room.message".into(), "dummy".to_owned())
&& **k != ("m.room.message".into(), "dummy".into())
})
.map(|(k, v)| (k.clone(), v.clone()))
.collect::<StateMap<OwnedEventId>>();
@ -239,7 +239,8 @@ impl<E: Event> TestStore<E> {
self.0
.get(event_id)
.cloned()
.ok_or_else(|| Error::NotFound(format!("{event_id} not found")))
.ok_or_else(|| super::Error::NotFound(format!("{event_id} not found")))
.map_err(Into::into)
}
/// Returns a Vec of the related auth events to the given `event`.
@ -582,8 +583,10 @@ pub(crate) fn INITIAL_EDGES() -> Vec<OwnedEventId> {
}
pub(crate) mod event {
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId};
use ruma_events::{pdu::Pdu, TimelineEventType};
use ruma::{
events::{pdu::Pdu, TimelineEventType},
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId,
};
use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue;

View file

@ -133,7 +133,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
));
}
let state_fetch = |ty: &'static StateEventType, sk: &str| {
let state_fetch = |ty: &StateEventType, sk: &str| {
let key = (ty.to_owned(), sk.into());
ready(auth_events.get(&key))
};

View file

@ -63,7 +63,6 @@ pub async fn resolve_state(
.multi_get_statekey_from_short(shortstatekeys)
.zip(event_ids)
.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
.map(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
.collect()
})
.map(Ok::<_, Error>)

View file

@ -172,7 +172,6 @@ async fn state_at_incoming_fork(
.short
.get_statekey_from_short(*k)
.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
.map_ok(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
})
.ready_filter_map(Result::ok)
.collect()

View file

@ -3,7 +3,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::In
use conduwuit::{
debug, debug_info, err, implement, state_res, trace,
utils::stream::{BroadbandExt, ReadyExt},
warn, Err, EventTypeExt, PduEvent, Result,
warn, Err, EventTypeExt, PduEvent, Result, StateKey,
};
use futures::{future::ready, FutureExt, StreamExt};
use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName};
@ -71,8 +71,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
debug!("Performing auth check");
// 11. Check the auth of the event passes based on the state of the event
let state_fetch_state = &state_at_incoming_event;
let state_fetch = |k: &'static StateEventType, s: String| async move {
let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?;
let state_fetch = |k: StateEventType, s: StateKey| async move {
let shortstatekey = self.services.short.get_shortstatekey(&k, &s).await.ok()?;
let event_id = state_fetch_state.get(&shortstatekey)?;
self.services.timeline.get_pdu(event_id).await.ok()
@ -82,7 +82,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
&room_version,
&incoming_pdu,
None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()),
|ty, sk| state_fetch(ty.clone(), sk.into()),
)
.await
.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
@ -104,7 +104,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
)
.await?;
let state_fetch = |k: &'static StateEventType, s: &str| {
let state_fetch = |k: &StateEventType, s: &str| {
let key = k.with_state_key(s);
ready(auth_events.get(&key).cloned())
};

View file

@ -747,7 +747,7 @@ impl Service {
};
let auth_fetch = |k: &StateEventType, s: &str| {
let key = (k.clone(), s.to_owned());
let key = (k.clone(), s.into());
ready(auth_events.get(&key))
};