optimize further into state-res with SmallString

triage and de-lints for state-res.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-02-08 01:58:13 +00:00 committed by strawberry
parent 0a9a9b3c92
commit f2ca670c3b
15 changed files with 192 additions and 145 deletions

1
Cargo.lock generated
View file

@ -810,6 +810,7 @@ dependencies = [
"libc", "libc",
"libloading", "libloading",
"log", "log",
"maplit",
"nix", "nix",
"num-traits", "num-traits",
"rand", "rand",

View file

@ -379,6 +379,7 @@ features = [
"unstable-msc4203", # sending to-device events to appservices "unstable-msc4203", # sending to-device events to appservices
"unstable-msc4210", # remove legacy mentions "unstable-msc4210", # remove legacy mentions
"unstable-extensible-events", "unstable-extensible-events",
"unstable-pdu",
] ]
[workspace.dependencies.rust-rocksdb] [workspace.dependencies.rust-rocksdb]
@ -527,6 +528,9 @@ features = ["std"]
version = "0.3.2" version = "0.3.2"
features = ["std"] features = ["std"]
[workspace.dependencies.maplit]
version = "1.0.2"
# #
# Patches # Patches
# #

View file

@ -14,7 +14,7 @@ use conduwuit::{
result::FlatOk, result::FlatOk,
state_res, trace, state_res, trace,
utils::{self, shuffle, IterStream, ReadyExt}, utils::{self, shuffle, IterStream, ReadyExt},
warn, Err, PduEvent, Result, warn, Err, PduEvent, Result, StateKey,
}; };
use futures::{join, FutureExt, StreamExt, TryFutureExt}; use futures::{join, FutureExt, StreamExt, TryFutureExt};
use ruma::{ use ruma::{
@ -1151,8 +1151,8 @@ async fn join_room_by_id_helper_remote(
debug!("Running send_join auth check"); debug!("Running send_join auth check");
let fetch_state = &state; let fetch_state = &state;
let state_fetch = |k: &'static StateEventType, s: String| async move { let state_fetch = |k: StateEventType, s: StateKey| async move {
let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?; let shortstatekey = services.rooms.short.get_shortstatekey(&k, &s).await.ok()?;
let event_id = fetch_state.get(&shortstatekey)?; let event_id = fetch_state.get(&shortstatekey)?;
services.rooms.timeline.get_pdu(event_id).await.ok() services.rooms.timeline.get_pdu(event_id).await.ok()
@ -1162,7 +1162,7 @@ async fn join_room_by_id_helper_remote(
&state_res::RoomVersion::new(&room_version_id)?, &state_res::RoomVersion::new(&room_version_id)?,
&parsed_join_pdu, &parsed_join_pdu,
None, // TODO: third party invite None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()), |k, s| state_fetch(k.clone(), s.into()),
) )
.await .await
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?; .map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;

View file

@ -395,9 +395,12 @@ pub(crate) async fn sync_events_v4_route(
.map_or(10, usize_from_u64_truncated) .map_or(10, usize_from_u64_truncated)
.min(100); .min(100);
todo_room todo_room.0.extend(
.0 list.room_details
.extend(list.room_details.required_state.iter().cloned()); .required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit); todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date // 0 means unknown because it got out of date
@ -449,7 +452,11 @@ pub(crate) async fn sync_events_v4_route(
.map_or(10, usize_from_u64_truncated) .map_or(10, usize_from_u64_truncated)
.min(100); .min(100);
todo_room.0.extend(room.required_state.iter().cloned()); todo_room.0.extend(
room.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit); todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date // 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min( todo_room.2 = todo_room.2.min(

View file

@ -223,7 +223,11 @@ async fn fetch_subscriptions(
let limit: UInt = room.timeline_limit; let limit: UInt = room.timeline_limit;
todo_room.0.extend(room.required_state.iter().cloned()); todo_room.0.extend(
room.required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(usize_from_ruma(limit)); todo_room.1 = todo_room.1.max(usize_from_ruma(limit));
// 0 means unknown because it got out of date // 0 means unknown because it got out of date
todo_room.2 = todo_room.2.min( todo_room.2 = todo_room.2.min(
@ -303,9 +307,12 @@ async fn handle_lists<'a>(
let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100); let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100);
todo_room todo_room.0.extend(
.0 list.room_details
.extend(list.room_details.required_state.iter().cloned()); .required_state
.iter()
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
);
todo_room.1 = todo_room.1.max(limit); todo_room.1 = todo_room.1.max(limit);
// 0 means unknown because it got out of date // 0 means unknown because it got out of date

View file

@ -116,5 +116,8 @@ nix.workspace = true
hardened_malloc-rs.workspace = true hardened_malloc-rs.workspace = true
hardened_malloc-rs.optional = true hardened_malloc-rs.optional = true
[dev-dependencies]
maplit.workspace = true
[lints] [lints]
workspace = true workspace = true

View file

@ -21,7 +21,6 @@ use serde::{
Deserialize, Deserialize,
}; };
use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue}; use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};
use tracing::{debug, error, instrument, trace, warn};
use super::{ use super::{
power_levels::{ power_levels::{
@ -29,8 +28,9 @@ use super::{
deserialize_power_levels_content_invite, deserialize_power_levels_content_redact, deserialize_power_levels_content_invite, deserialize_power_levels_content_redact,
}, },
room_version::RoomVersion, room_version::RoomVersion,
Error, Event, Result, StateEventType, TimelineEventType, Error, Event, Result, StateEventType, StateKey, TimelineEventType,
}; };
use crate::{debug, error, trace, warn};
// FIXME: field extracting could be bundled for `content` // FIXME: field extracting could be bundled for `content`
#[derive(Deserialize)] #[derive(Deserialize)]
@ -56,15 +56,15 @@ pub fn auth_types_for_event(
sender: &UserId, sender: &UserId,
state_key: Option<&str>, state_key: Option<&str>,
content: &RawJsonValue, content: &RawJsonValue,
) -> serde_json::Result<Vec<(StateEventType, String)>> { ) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
if kind == &TimelineEventType::RoomCreate { if kind == &TimelineEventType::RoomCreate {
return Ok(vec![]); return Ok(vec![]);
} }
let mut auth_types = vec![ let mut auth_types = vec![
(StateEventType::RoomPowerLevels, String::new()), (StateEventType::RoomPowerLevels, StateKey::new()),
(StateEventType::RoomMember, sender.to_string()), (StateEventType::RoomMember, sender.as_str().into()),
(StateEventType::RoomCreate, String::new()), (StateEventType::RoomCreate, StateKey::new()),
]; ];
if kind == &TimelineEventType::RoomMember { if kind == &TimelineEventType::RoomMember {
@ -82,7 +82,7 @@ pub fn auth_types_for_event(
if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock] if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock]
.contains(&membership) .contains(&membership)
{ {
let key = (StateEventType::RoomJoinRules, String::new()); let key = (StateEventType::RoomJoinRules, StateKey::new());
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key); auth_types.push(key);
} }
@ -91,21 +91,22 @@ pub fn auth_types_for_event(
.join_authorised_via_users_server .join_authorised_via_users_server
.map(|m| m.deserialize()) .map(|m| m.deserialize())
{ {
let key = (StateEventType::RoomMember, u.to_string()); let key = (StateEventType::RoomMember, u.as_str().into());
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key); auth_types.push(key);
} }
} }
} }
let key = (StateEventType::RoomMember, state_key.to_owned()); let key = (StateEventType::RoomMember, state_key.into());
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key); auth_types.push(key);
} }
if membership == MembershipState::Invite { if membership == MembershipState::Invite {
if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) { if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) {
let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token); let key =
(StateEventType::RoomThirdPartyInvite, t_id.signed.token.into());
if !auth_types.contains(&key) { if !auth_types.contains(&key) {
auth_types.push(key); auth_types.push(key);
} }
@ -128,7 +129,13 @@ pub fn auth_types_for_event(
/// The `fetch_state` closure should gather state from a state snapshot. We need /// The `fetch_state` closure should gather state from a state snapshot. We need
/// to know if the event passes auth against some state not a recursive /// to know if the event passes auth against some state not a recursive
/// collection of auth_events fields. /// collection of auth_events fields.
#[instrument(level = "debug", skip_all, fields(event_id = incoming_event.event_id().borrow().as_str()))] #[tracing::instrument(
level = "debug",
skip_all,
fields(
event_id = incoming_event.event_id().borrow().as_str()
)
)]
pub async fn auth_check<F, Fut, Fetched, Incoming>( pub async fn auth_check<F, Fut, Fetched, Incoming>(
room_version: &RoomVersion, room_version: &RoomVersion,
incoming_event: &Incoming, incoming_event: &Incoming,
@ -136,10 +143,10 @@ pub async fn auth_check<F, Fut, Fetched, Incoming>(
fetch_state: F, fetch_state: F,
) -> Result<bool, Error> ) -> Result<bool, Error>
where where
F: Fn(&'static StateEventType, &str) -> Fut, F: Fn(&StateEventType, &str) -> Fut + Send,
Fut: Future<Output = Option<Fetched>> + Send, Fut: Future<Output = Option<Fetched>> + Send,
Fetched: Event + Send, Fetched: Event + Send,
Incoming: Event + Send, Incoming: Event + Send + Sync,
{ {
debug!( debug!(
"auth_check beginning for {} ({})", "auth_check beginning for {} ({})",
@ -262,6 +269,7 @@ where
// sender domain of the event does not match the sender domain of the create // sender domain of the event does not match the sender domain of the create
// event, reject. // event, reject.
#[derive(Deserialize)] #[derive(Deserialize)]
#[allow(clippy::items_after_statements)]
struct RoomCreateContentFederate { struct RoomCreateContentFederate {
#[serde(rename = "m.federate", default = "ruma::serde::default_true")] #[serde(rename = "m.federate", default = "ruma::serde::default_true")]
federate: bool, federate: bool,
@ -354,7 +362,7 @@ where
join_rules_event.as_ref(), join_rules_event.as_ref(),
user_for_join_auth.as_deref(), user_for_join_auth.as_deref(),
&user_for_join_auth_membership, &user_for_join_auth_membership,
room_create_event, &room_create_event,
)? { )? {
return Ok(false); return Ok(false);
} }
@ -364,6 +372,7 @@ where
} }
// If the sender's current membership state is not join, reject // If the sender's current membership state is not join, reject
#[allow(clippy::manual_let_else)]
let sender_member_event = match sender_member_event { let sender_member_event = match sender_member_event {
| Some(mem) => mem, | Some(mem) => mem,
| None => { | None => {
@ -498,19 +507,20 @@ where
/// This is generated by calling `auth_types_for_event` with the membership /// This is generated by calling `auth_types_for_event` with the membership
/// event and the current State. /// event and the current State.
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
#[allow(clippy::cognitive_complexity)]
fn valid_membership_change( fn valid_membership_change(
room_version: &RoomVersion, room_version: &RoomVersion,
target_user: &UserId, target_user: &UserId,
target_user_membership_event: Option<impl Event>, target_user_membership_event: Option<&impl Event>,
sender: &UserId, sender: &UserId,
sender_membership_event: Option<impl Event>, sender_membership_event: Option<&impl Event>,
current_event: impl Event, current_event: impl Event,
current_third_party_invite: Option<impl Event>, current_third_party_invite: Option<&impl Event>,
power_levels_event: Option<impl Event>, power_levels_event: Option<&impl Event>,
join_rules_event: Option<impl Event>, join_rules_event: Option<&impl Event>,
user_for_join_auth: Option<&UserId>, user_for_join_auth: Option<&UserId>,
user_for_join_auth_membership: &MembershipState, user_for_join_auth_membership: &MembershipState,
create_room: impl Event, create_room: &impl Event,
) -> Result<bool> { ) -> Result<bool> {
#[derive(Deserialize)] #[derive(Deserialize)]
struct GetThirdPartyInvite { struct GetThirdPartyInvite {
@ -856,6 +866,7 @@ fn check_power_levels(
// and integers here // and integers here
debug!("validation of power event finished"); debug!("validation of power event finished");
#[allow(clippy::manual_let_else)]
let current_state = match previous_power_event { let current_state = match previous_power_event {
| Some(current_state) => current_state, | Some(current_state) => current_state,
// If there is no previous m.room.power_levels event in the room, allow // If there is no previous m.room.power_levels event in the room, allow
@ -1054,6 +1065,7 @@ fn verify_third_party_invite(
// If there is no m.room.third_party_invite event in the current room state with // If there is no m.room.third_party_invite event in the current room state with
// state_key matching token, reject // state_key matching token, reject
#[allow(clippy::manual_let_else)]
let current_tpid = match current_third_party_invite { let current_tpid = match current_third_party_invite {
| Some(id) => id, | Some(id) => id,
| None => return false, | None => return false,
@ -1069,12 +1081,14 @@ fn verify_third_party_invite(
// If any signature in signed matches any public key in the // If any signature in signed matches any public key in the
// m.room.third_party_invite event, allow // m.room.third_party_invite event, allow
#[allow(clippy::manual_let_else)]
let tpid_ev = let tpid_ev =
match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) { match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) {
| Ok(ev) => ev, | Ok(ev) => ev,
| Err(_) => return false, | Err(_) => return false,
}; };
#[allow(clippy::manual_let_else)]
let decoded_invite_token = match Base64::parse(&tp_id.signed.token) { let decoded_invite_token = match Base64::parse(&tp_id.signed.token) {
| Ok(tok) => tok, | Ok(tok) => tok,
// FIXME: Log a warning? // FIXME: Log a warning?
@ -1096,7 +1110,7 @@ fn verify_third_party_invite(
mod tests { mod tests {
use std::sync::Arc; use std::sync::Arc;
use ruma_events::{ use ruma::events::{
room::{ room::{
join_rules::{ join_rules::{
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership, AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
@ -1107,7 +1121,7 @@ mod tests {
}; };
use serde_json::value::to_raw_value as to_raw_json_value; use serde_json::value::to_raw_value as to_raw_json_value;
use crate::{ use crate::state_res::{
event_auth::valid_membership_change, event_auth::valid_membership_change,
test_utils::{ test_utils::{
alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id, alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id,
@ -1145,16 +1159,16 @@ mod tests {
assert!(valid_membership_change( assert!(valid_membership_change(
&RoomVersion::V6, &RoomVersion::V6,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None, None,
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }
@ -1188,16 +1202,16 @@ mod tests {
assert!(!valid_membership_change( assert!(!valid_membership_change(
&RoomVersion::V6, &RoomVersion::V6,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None, None,
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }
@ -1231,16 +1245,16 @@ mod tests {
assert!(valid_membership_change( assert!(valid_membership_change(
&RoomVersion::V6, &RoomVersion::V6,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None, None,
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }
@ -1274,16 +1288,16 @@ mod tests {
assert!(!valid_membership_change( assert!(!valid_membership_change(
&RoomVersion::V6, &RoomVersion::V6,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None, None,
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }
@ -1334,32 +1348,32 @@ mod tests {
assert!(valid_membership_change( assert!(valid_membership_change(
&RoomVersion::V9, &RoomVersion::V9,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(alice()), Some(alice()),
&MembershipState::Join, &MembershipState::Join,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
assert!(!valid_membership_change( assert!(!valid_membership_change(
&RoomVersion::V9, &RoomVersion::V9,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
Some(ella()), Some(ella()),
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }
@ -1402,16 +1416,16 @@ mod tests {
assert!(valid_membership_change( assert!(valid_membership_change(
&RoomVersion::V7, &RoomVersion::V7,
target_user, target_user,
fetch_state(StateEventType::RoomMember, target_user.to_string()), fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
sender, sender,
fetch_state(StateEventType::RoomMember, sender.to_string()), fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
&requester, &requester,
None::<PduEvent>, None::<&PduEvent>,
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()), fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
fetch_state(StateEventType::RoomJoinRules, "".to_owned()), fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
None, None,
&MembershipState::Leave, &MembershipState::Leave,
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(), &fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
) )
.unwrap()); .unwrap());
} }

View file

@ -1,3 +1,5 @@
#![cfg_attr(test, allow(warnings))]
pub(crate) mod error; pub(crate) mod error;
pub mod event_auth; pub mod event_auth;
mod power_levels; mod power_levels;
@ -12,7 +14,7 @@ use std::{
cmp::{Ordering, Reverse}, cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet}, collections::{BinaryHeap, HashMap, HashSet},
fmt::Debug, fmt::Debug,
hash::Hash, hash::{BuildHasher, Hash},
}; };
use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt}; use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
@ -32,13 +34,13 @@ pub use self::{
room_version::RoomVersion, room_version::RoomVersion,
state_event::Event, state_event::Event,
}; };
use crate::{debug, trace, warn}; use crate::{debug, pdu::StateKey, trace, warn};
/// A mapping of event type and state_key to some value `T`, usually an /// A mapping of event type and state_key to some value `T`, usually an
/// `EventId`. /// `EventId`.
pub type StateMap<T> = HashMap<TypeStateKey, T>; pub type StateMap<T> = HashMap<TypeStateKey, T>;
pub type StateMapItem<T> = (TypeStateKey, T); pub type StateMapItem<T> = (TypeStateKey, T);
pub type TypeStateKey = (StateEventType, String); pub type TypeStateKey = (StateEventType, StateKey);
type Result<T, E = Error> = crate::Result<T, E>; type Result<T, E = Error> = crate::Result<T, E>;
@ -68,10 +70,10 @@ type Result<T, E = Error> = crate::Result<T, E>;
/// event is part of the same room. /// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets, //#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
//#[tracing::instrument(level event_fetch))] //#[tracing::instrument(level event_fetch))]
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>( pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId, room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter> + Send, state_sets: Sets,
auth_chain_sets: &'a [HashSet<E::Id>], auth_chain_sets: &'a [HashSet<E::Id, Hasher>],
event_fetch: &Fetch, event_fetch: &Fetch,
event_exists: &Exists, event_exists: &Exists,
parallel_fetches: usize, parallel_fetches: usize,
@ -81,7 +83,9 @@ where
FetchFut: Future<Output = Option<E>> + Send, FetchFut: Future<Output = Option<E>> + Send,
Exists: Fn(E::Id) -> ExistsFut + Sync, Exists: Fn(E::Id) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send, ExistsFut: Future<Output = bool> + Send,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send, SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
E: Event + Clone + Send + Sync, E: Event + Clone + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync, E::Id: Borrow<EventId> + Send + Sync,
for<'b> &'b E: Send, for<'b> &'b E: Send,
@ -178,7 +182,7 @@ where
trace!(list = ?events_to_resolve, "events left to resolve"); trace!(list = ?events_to_resolve, "events left to resolve");
// This "epochs" power level event // This "epochs" power level event
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new())); let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new()));
debug!(event_id = ?power_event, "power event"); debug!(event_id = ?power_event, "power event");
@ -222,16 +226,17 @@ fn separate<'a, Id>(
where where
Id: Clone + Eq + Hash + 'a, Id: Clone + Eq + Hash + 'a,
{ {
let mut state_set_count = 0_usize; let mut state_set_count: usize = 0;
let mut occurrences = HashMap::<_, HashMap<_, _>>::new(); let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1); let state_sets_iter =
state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1));
for (k, v) in state_sets_iter.flatten() { for (k, v) in state_sets_iter.flatten() {
occurrences occurrences
.entry(k) .entry(k)
.or_default() .or_default()
.entry(v) .entry(v)
.and_modify(|x| *x += 1) .and_modify(|x: &mut usize| *x = x.saturating_add(1))
.or_insert(1); .or_insert(1);
} }
@ -246,7 +251,7 @@ where
conflicted_state conflicted_state
.entry((k.0.clone(), k.1.clone())) .entry((k.0.clone(), k.1.clone()))
.and_modify(|x: &mut Vec<_>| x.push(id.clone())) .and_modify(|x: &mut Vec<_>| x.push(id.clone()))
.or_insert(vec![id.clone()]); .or_insert_with(|| vec![id.clone()]);
} }
} }
} }
@ -255,9 +260,13 @@ where
} }
/// Returns a Vec of deduped EventIds that appear in some chains but not others. /// Returns a Vec of deduped EventIds that appear in some chains but not others.
fn get_auth_chain_diff<Id>(auth_chain_sets: &[HashSet<Id>]) -> impl Iterator<Item = Id> + Send #[allow(clippy::arithmetic_side_effects)]
fn get_auth_chain_diff<Id, Hasher>(
auth_chain_sets: &[HashSet<Id, Hasher>],
) -> impl Iterator<Item = Id> + Send
where where
Id: Clone + Eq + Hash + Send, Id: Clone + Eq + Hash + Send,
Hasher: BuildHasher + Send + Sync,
{ {
let num_sets = auth_chain_sets.len(); let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Id, usize> = HashMap::new(); let mut id_counts: HashMap<Id, usize> = HashMap::new();
@ -288,7 +297,7 @@ async fn reverse_topological_power_sort<E, F, Fut>(
where where
F: Fn(E::Id) -> Fut + Sync, F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send, E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync, E::Id: Borrow<EventId> + Send + Sync,
{ {
debug!("reverse topological sort of power events"); debug!("reverse topological sort of power events");
@ -337,14 +346,15 @@ where
/// `key_fn` is used as to obtain the power level and age of an event for /// `key_fn` is used as to obtain the power level and age of an event for
/// breaking ties (together with the event ID). /// breaking ties (together with the event ID).
#[tracing::instrument(level = "debug", skip_all)] #[tracing::instrument(level = "debug", skip_all)]
pub async fn lexicographical_topological_sort<Id, F, Fut>( pub async fn lexicographical_topological_sort<Id, F, Fut, Hasher>(
graph: &HashMap<Id, HashSet<Id>>, graph: &HashMap<Id, HashSet<Id, Hasher>>,
key_fn: &F, key_fn: &F,
) -> Result<Vec<Id>> ) -> Result<Vec<Id>>
where where
F: Fn(Id) -> Fut + Sync, F: Fn(Id) -> Fut + Sync,
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send, Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send, Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send + Sync,
Hasher: BuildHasher + Default + Clone + Send + Sync,
{ {
#[derive(PartialEq, Eq)] #[derive(PartialEq, Eq)]
struct TieBreaker<'a, Id> { struct TieBreaker<'a, Id> {
@ -395,7 +405,7 @@ where
// The number of events that depend on the given event (the EventId key) // The number of events that depend on the given event (the EventId key)
// How many events reference this event in the DAG as a parent // How many events reference this event in the DAG as a parent
let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new(); let mut reverse_graph: HashMap<_, HashSet<_, Hasher>> = HashMap::new();
// Vec of nodes that have zero out degree, least recent events. // Vec of nodes that have zero out degree, least recent events.
let mut zero_outdegree = Vec::new(); let mut zero_outdegree = Vec::new();
@ -727,8 +737,8 @@ async fn get_mainline_depth<E, F, Fut>(
where where
F: Fn(E::Id) -> Fut + Sync, F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send, E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send, E::Id: Borrow<EventId> + Send + Sync,
{ {
while let Some(sort_ev) = event { while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline"); debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
@ -758,10 +768,10 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
auth_diff: &HashSet<E::Id>, auth_diff: &HashSet<E::Id>,
fetch_event: &F, fetch_event: &F,
) where ) where
F: Fn(E::Id) -> Fut, F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send, E: Event + Send + Sync,
E::Id: Borrow<EventId> + Clone + Send, E::Id: Borrow<EventId> + Clone + Send + Sync,
{ {
let mut state = vec![event_id]; let mut state = vec![event_id];
while let Some(eid) = state.pop() { while let Some(eid) = state.pop() {
@ -788,7 +798,7 @@ where
F: Fn(E::Id) -> Fut + Sync, F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send, Fut: Future<Output = Option<E>> + Send,
E: Event + Send, E: Event + Send,
E::Id: Borrow<EventId> + Send, E::Id: Borrow<EventId> + Send + Sync,
{ {
match fetch(event_id.clone()).await.as_ref() { match fetch(event_id.clone()).await.as_ref() {
| Some(state) => is_power_event(state), | Some(state) => is_power_event(state),
@ -820,18 +830,18 @@ fn is_power_event(event: impl Event) -> bool {
/// Convenience trait for adding event type plus state key to state maps. /// Convenience trait for adding event type plus state key to state maps.
pub trait EventTypeExt { pub trait EventTypeExt {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String); fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey);
} }
impl EventTypeExt for StateEventType { impl EventTypeExt for StateEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) { fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self, state_key.into()) (self, state_key.into())
} }
} }
impl EventTypeExt for TimelineEventType { impl EventTypeExt for TimelineEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) { fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self.to_string().into(), state_key.into()) (self.into(), state_key.into())
} }
} }
@ -839,7 +849,7 @@ impl<T> EventTypeExt for &T
where where
T: EventTypeExt + Clone, T: EventTypeExt + Clone,
{ {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) { fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
self.to_owned().with_state_key(state_key) self.to_owned().with_state_key(state_key)
} }
} }
@ -858,13 +868,11 @@ mod tests {
room::join_rules::{JoinRule, RoomJoinRulesEventContent}, room::join_rules::{JoinRule, RoomJoinRulesEventContent},
StateEventType, TimelineEventType, StateEventType, TimelineEventType,
}, },
int, uint, int, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
}; };
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId};
use serde_json::{json, value::to_raw_value as to_raw_json_value}; use serde_json::{json, value::to_raw_value as to_raw_json_value};
use tracing::debug;
use crate::{ use super::{
is_power_event, is_power_event,
room_version::RoomVersion, room_version::RoomVersion,
test_utils::{ test_utils::{
@ -874,6 +882,7 @@ mod tests {
}, },
Event, EventTypeExt, StateMap, Event, EventTypeExt, StateMap,
}; };
use crate::debug;
async fn test_event_sort() { async fn test_event_sort() {
use futures::future::ready; use futures::future::ready;
@ -898,11 +907,11 @@ mod tests {
let fetcher = |id| ready(events.get(&id).cloned()); let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events = let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1) super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
.await .await
.unwrap(); .unwrap();
let resolved_power = crate::iterative_auth_check( let resolved_power = super::iterative_auth_check(
&RoomVersion::V6, &RoomVersion::V6,
sorted_power_events.iter(), sorted_power_events.iter(),
HashMap::new(), // unconflicted events HashMap::new(), // unconflicted events
@ -918,10 +927,10 @@ mod tests {
events_to_sort.shuffle(&mut rand::thread_rng()); events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".to_owned())) .get(&(StateEventType::RoomPowerLevels, "".into()))
.cloned(); .cloned();
let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1) let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
.await .await
.unwrap(); .unwrap();
@ -1302,7 +1311,7 @@ mod tests {
}) })
.collect(); .collect();
let resolved = match crate::resolve( let resolved = match super::resolve(
&RoomVersionId::V2, &RoomVersionId::V2,
&state_sets, &state_sets,
&auth_chain, &auth_chain,
@ -1333,7 +1342,7 @@ mod tests {
event_id("p") => hashset![event_id("o")], event_id("p") => hashset![event_id("o")],
}; };
let res = crate::lexicographical_topological_sort(&graph, &|_id| async { let res = super::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
}) })
.await .await
@ -1421,7 +1430,7 @@ mod tests {
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned()); let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some()); let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
let resolved = match crate::resolve( let resolved = match super::resolve(
&RoomVersionId::V6, &RoomVersionId::V6,
&state_sets, &state_sets,
&auth_chain, &auth_chain,
@ -1552,7 +1561,7 @@ mod tests {
#[allow(unused_mut)] #[allow(unused_mut)]
let mut x = StateMap::new(); let mut x = StateMap::new();
$( $(
x.insert(($kind, $key.to_owned()), $id); x.insert(($kind, $key.into()), $id);
)* )*
x x
}}; }};

View file

@ -32,6 +32,7 @@ pub enum StateResolutionVersion {
} }
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)] #[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
#[allow(clippy::struct_excessive_bools)]
pub struct RoomVersion { pub struct RoomVersion {
/// The stability of this room. /// The stability of this room.
pub disposition: RoomDisposition, pub disposition: RoomDisposition,

View file

@ -7,28 +7,28 @@ use std::{
}, },
}; };
use futures_util::future::ready; use futures::future::ready;
use js_int::{int, uint}; use ruma::{
use ruma_common::{ event_id,
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, events::{
RoomVersionId, ServerSignatures, UserId, pdu::{EventHash, Pdu, RoomV3Pdu},
}; room::{
use ruma_events::{ join_rules::{JoinRule, RoomJoinRulesEventContent},
pdu::{EventHash, Pdu, RoomV3Pdu}, member::{MembershipState, RoomMemberEventContent},
room::{ },
join_rules::{JoinRule, RoomJoinRulesEventContent}, TimelineEventType,
member::{MembershipState, RoomMemberEventContent},
}, },
TimelineEventType, int, room_id, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
RoomVersionId, ServerSignatures, UserId,
}; };
use serde_json::{ use serde_json::{
json, json,
value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue}, value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue},
}; };
use tracing::info;
pub(crate) use self::event::PduEvent; pub(crate) use self::event::PduEvent;
use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap}; use super::auth_types_for_event;
use crate::{info, Event, EventTypeExt, Result, StateMap};
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0); static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
@ -88,7 +88,7 @@ pub(crate) async fn do_check(
// Resolve the current state and add it to the state_at_event map then continue // Resolve the current state and add it to the state_at_event map then continue
// on in "time" // on in "time"
for node in crate::lexicographical_topological_sort(&graph, &|_id| async { for node in super::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0)))) Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
}) })
.await .await
@ -135,7 +135,7 @@ pub(crate) async fn do_check(
let event_map = &event_map; let event_map = &event_map;
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned()); let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some()); let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
let resolved = crate::resolve( let resolved = super::resolve(
&RoomVersionId::V6, &RoomVersionId::V6,
state_sets, state_sets,
&auth_chain_sets, &auth_chain_sets,
@ -223,7 +223,7 @@ pub(crate) async fn do_check(
// Filter out the dummy messages events. // Filter out the dummy messages events.
// These act as points in time where there should be a known state to // These act as points in time where there should be a known state to
// test against. // test against.
&& **k != ("m.room.message".into(), "dummy".to_owned()) && **k != ("m.room.message".into(), "dummy".into())
}) })
.map(|(k, v)| (k.clone(), v.clone())) .map(|(k, v)| (k.clone(), v.clone()))
.collect::<StateMap<OwnedEventId>>(); .collect::<StateMap<OwnedEventId>>();
@ -239,7 +239,8 @@ impl<E: Event> TestStore<E> {
self.0 self.0
.get(event_id) .get(event_id)
.cloned() .cloned()
.ok_or_else(|| Error::NotFound(format!("{event_id} not found"))) .ok_or_else(|| super::Error::NotFound(format!("{event_id} not found")))
.map_err(Into::into)
} }
/// Returns a Vec of the related auth events to the given `event`. /// Returns a Vec of the related auth events to the given `event`.
@ -582,8 +583,10 @@ pub(crate) fn INITIAL_EDGES() -> Vec<OwnedEventId> {
} }
pub(crate) mod event { pub(crate) mod event {
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId}; use ruma::{
use ruma_events::{pdu::Pdu, TimelineEventType}; events::{pdu::Pdu, TimelineEventType},
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId,
};
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
use serde_json::value::RawValue as RawJsonValue; use serde_json::value::RawValue as RawJsonValue;

View file

@ -133,7 +133,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
)); ));
} }
let state_fetch = |ty: &'static StateEventType, sk: &str| { let state_fetch = |ty: &StateEventType, sk: &str| {
let key = (ty.to_owned(), sk.into()); let key = (ty.to_owned(), sk.into());
ready(auth_events.get(&key)) ready(auth_events.get(&key))
}; };

View file

@ -63,7 +63,6 @@ pub async fn resolve_state(
.multi_get_statekey_from_short(shortstatekeys) .multi_get_statekey_from_short(shortstatekeys)
.zip(event_ids) .zip(event_ids)
.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id))) .ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
.map(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
.collect() .collect()
}) })
.map(Ok::<_, Error>) .map(Ok::<_, Error>)

View file

@ -172,7 +172,6 @@ async fn state_at_incoming_fork(
.short .short
.get_statekey_from_short(*k) .get_statekey_from_short(*k)
.map_ok(|(ty, sk)| ((ty, sk), id.clone())) .map_ok(|(ty, sk)| ((ty, sk), id.clone()))
.map_ok(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
}) })
.ready_filter_map(Result::ok) .ready_filter_map(Result::ok)
.collect() .collect()

View file

@ -3,7 +3,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::In
use conduwuit::{ use conduwuit::{
debug, debug_info, err, implement, state_res, trace, debug, debug_info, err, implement, state_res, trace,
utils::stream::{BroadbandExt, ReadyExt}, utils::stream::{BroadbandExt, ReadyExt},
warn, Err, EventTypeExt, PduEvent, Result, warn, Err, EventTypeExt, PduEvent, Result, StateKey,
}; };
use futures::{future::ready, FutureExt, StreamExt}; use futures::{future::ready, FutureExt, StreamExt};
use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName}; use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName};
@ -71,8 +71,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
debug!("Performing auth check"); debug!("Performing auth check");
// 11. Check the auth of the event passes based on the state of the event // 11. Check the auth of the event passes based on the state of the event
let state_fetch_state = &state_at_incoming_event; let state_fetch_state = &state_at_incoming_event;
let state_fetch = |k: &'static StateEventType, s: String| async move { let state_fetch = |k: StateEventType, s: StateKey| async move {
let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?; let shortstatekey = self.services.short.get_shortstatekey(&k, &s).await.ok()?;
let event_id = state_fetch_state.get(&shortstatekey)?; let event_id = state_fetch_state.get(&shortstatekey)?;
self.services.timeline.get_pdu(event_id).await.ok() self.services.timeline.get_pdu(event_id).await.ok()
@ -82,7 +82,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
&room_version, &room_version,
&incoming_pdu, &incoming_pdu,
None, // TODO: third party invite None, // TODO: third party invite
|k, s| state_fetch(k, s.to_owned()), |ty, sk| state_fetch(ty.clone(), sk.into()),
) )
.await .await
.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?; .map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
@ -104,7 +104,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
) )
.await?; .await?;
let state_fetch = |k: &'static StateEventType, s: &str| { let state_fetch = |k: &StateEventType, s: &str| {
let key = k.with_state_key(s); let key = k.with_state_key(s);
ready(auth_events.get(&key).cloned()) ready(auth_events.get(&key).cloned())
}; };

View file

@ -747,7 +747,7 @@ impl Service {
}; };
let auth_fetch = |k: &StateEventType, s: &str| { let auth_fetch = |k: &StateEventType, s: &str| {
let key = (k.clone(), s.to_owned()); let key = (k.clone(), s.into());
ready(auth_events.get(&key)) ready(auth_events.get(&key))
}; };