optimize further into state-res with SmallString
triage and de-lints for state-res. Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
parent
0a9a9b3c92
commit
f2ca670c3b
15 changed files with 192 additions and 145 deletions
1
Cargo.lock
generated
1
Cargo.lock
generated
|
@ -810,6 +810,7 @@ dependencies = [
|
||||||
"libc",
|
"libc",
|
||||||
"libloading",
|
"libloading",
|
||||||
"log",
|
"log",
|
||||||
|
"maplit",
|
||||||
"nix",
|
"nix",
|
||||||
"num-traits",
|
"num-traits",
|
||||||
"rand",
|
"rand",
|
||||||
|
|
|
@ -379,6 +379,7 @@ features = [
|
||||||
"unstable-msc4203", # sending to-device events to appservices
|
"unstable-msc4203", # sending to-device events to appservices
|
||||||
"unstable-msc4210", # remove legacy mentions
|
"unstable-msc4210", # remove legacy mentions
|
||||||
"unstable-extensible-events",
|
"unstable-extensible-events",
|
||||||
|
"unstable-pdu",
|
||||||
]
|
]
|
||||||
|
|
||||||
[workspace.dependencies.rust-rocksdb]
|
[workspace.dependencies.rust-rocksdb]
|
||||||
|
@ -527,6 +528,9 @@ features = ["std"]
|
||||||
version = "0.3.2"
|
version = "0.3.2"
|
||||||
features = ["std"]
|
features = ["std"]
|
||||||
|
|
||||||
|
[workspace.dependencies.maplit]
|
||||||
|
version = "1.0.2"
|
||||||
|
|
||||||
#
|
#
|
||||||
# Patches
|
# Patches
|
||||||
#
|
#
|
||||||
|
|
|
@ -14,7 +14,7 @@ use conduwuit::{
|
||||||
result::FlatOk,
|
result::FlatOk,
|
||||||
state_res, trace,
|
state_res, trace,
|
||||||
utils::{self, shuffle, IterStream, ReadyExt},
|
utils::{self, shuffle, IterStream, ReadyExt},
|
||||||
warn, Err, PduEvent, Result,
|
warn, Err, PduEvent, Result, StateKey,
|
||||||
};
|
};
|
||||||
use futures::{join, FutureExt, StreamExt, TryFutureExt};
|
use futures::{join, FutureExt, StreamExt, TryFutureExt};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
|
@ -1151,8 +1151,8 @@ async fn join_room_by_id_helper_remote(
|
||||||
|
|
||||||
debug!("Running send_join auth check");
|
debug!("Running send_join auth check");
|
||||||
let fetch_state = &state;
|
let fetch_state = &state;
|
||||||
let state_fetch = |k: &'static StateEventType, s: String| async move {
|
let state_fetch = |k: StateEventType, s: StateKey| async move {
|
||||||
let shortstatekey = services.rooms.short.get_shortstatekey(k, &s).await.ok()?;
|
let shortstatekey = services.rooms.short.get_shortstatekey(&k, &s).await.ok()?;
|
||||||
|
|
||||||
let event_id = fetch_state.get(&shortstatekey)?;
|
let event_id = fetch_state.get(&shortstatekey)?;
|
||||||
services.rooms.timeline.get_pdu(event_id).await.ok()
|
services.rooms.timeline.get_pdu(event_id).await.ok()
|
||||||
|
@ -1162,7 +1162,7 @@ async fn join_room_by_id_helper_remote(
|
||||||
&state_res::RoomVersion::new(&room_version_id)?,
|
&state_res::RoomVersion::new(&room_version_id)?,
|
||||||
&parsed_join_pdu,
|
&parsed_join_pdu,
|
||||||
None, // TODO: third party invite
|
None, // TODO: third party invite
|
||||||
|k, s| state_fetch(k, s.to_owned()),
|
|k, s| state_fetch(k.clone(), s.into()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
|
.map_err(|e| err!(Request(Forbidden(warn!("Auth check failed: {e:?}")))))?;
|
||||||
|
|
|
@ -395,9 +395,12 @@ pub(crate) async fn sync_events_v4_route(
|
||||||
.map_or(10, usize_from_u64_truncated)
|
.map_or(10, usize_from_u64_truncated)
|
||||||
.min(100);
|
.min(100);
|
||||||
|
|
||||||
todo_room
|
todo_room.0.extend(
|
||||||
.0
|
list.room_details
|
||||||
.extend(list.room_details.required_state.iter().cloned());
|
.required_state
|
||||||
|
.iter()
|
||||||
|
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
|
||||||
|
);
|
||||||
|
|
||||||
todo_room.1 = todo_room.1.max(limit);
|
todo_room.1 = todo_room.1.max(limit);
|
||||||
// 0 means unknown because it got out of date
|
// 0 means unknown because it got out of date
|
||||||
|
@ -449,7 +452,11 @@ pub(crate) async fn sync_events_v4_route(
|
||||||
.map_or(10, usize_from_u64_truncated)
|
.map_or(10, usize_from_u64_truncated)
|
||||||
.min(100);
|
.min(100);
|
||||||
|
|
||||||
todo_room.0.extend(room.required_state.iter().cloned());
|
todo_room.0.extend(
|
||||||
|
room.required_state
|
||||||
|
.iter()
|
||||||
|
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
|
||||||
|
);
|
||||||
todo_room.1 = todo_room.1.max(limit);
|
todo_room.1 = todo_room.1.max(limit);
|
||||||
// 0 means unknown because it got out of date
|
// 0 means unknown because it got out of date
|
||||||
todo_room.2 = todo_room.2.min(
|
todo_room.2 = todo_room.2.min(
|
||||||
|
|
|
@ -223,7 +223,11 @@ async fn fetch_subscriptions(
|
||||||
|
|
||||||
let limit: UInt = room.timeline_limit;
|
let limit: UInt = room.timeline_limit;
|
||||||
|
|
||||||
todo_room.0.extend(room.required_state.iter().cloned());
|
todo_room.0.extend(
|
||||||
|
room.required_state
|
||||||
|
.iter()
|
||||||
|
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
|
||||||
|
);
|
||||||
todo_room.1 = todo_room.1.max(usize_from_ruma(limit));
|
todo_room.1 = todo_room.1.max(usize_from_ruma(limit));
|
||||||
// 0 means unknown because it got out of date
|
// 0 means unknown because it got out of date
|
||||||
todo_room.2 = todo_room.2.min(
|
todo_room.2 = todo_room.2.min(
|
||||||
|
@ -303,9 +307,12 @@ async fn handle_lists<'a>(
|
||||||
|
|
||||||
let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100);
|
let limit: usize = usize_from_ruma(list.room_details.timeline_limit).min(100);
|
||||||
|
|
||||||
todo_room
|
todo_room.0.extend(
|
||||||
.0
|
list.room_details
|
||||||
.extend(list.room_details.required_state.iter().cloned());
|
.required_state
|
||||||
|
.iter()
|
||||||
|
.map(|(ty, sk)| (ty.clone(), sk.as_str().into())),
|
||||||
|
);
|
||||||
|
|
||||||
todo_room.1 = todo_room.1.max(limit);
|
todo_room.1 = todo_room.1.max(limit);
|
||||||
// 0 means unknown because it got out of date
|
// 0 means unknown because it got out of date
|
||||||
|
|
|
@ -116,5 +116,8 @@ nix.workspace = true
|
||||||
hardened_malloc-rs.workspace = true
|
hardened_malloc-rs.workspace = true
|
||||||
hardened_malloc-rs.optional = true
|
hardened_malloc-rs.optional = true
|
||||||
|
|
||||||
|
[dev-dependencies]
|
||||||
|
maplit.workspace = true
|
||||||
|
|
||||||
[lints]
|
[lints]
|
||||||
workspace = true
|
workspace = true
|
||||||
|
|
|
@ -21,7 +21,6 @@ use serde::{
|
||||||
Deserialize,
|
Deserialize,
|
||||||
};
|
};
|
||||||
use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};
|
use serde_json::{from_str as from_json_str, value::RawValue as RawJsonValue};
|
||||||
use tracing::{debug, error, instrument, trace, warn};
|
|
||||||
|
|
||||||
use super::{
|
use super::{
|
||||||
power_levels::{
|
power_levels::{
|
||||||
|
@ -29,8 +28,9 @@ use super::{
|
||||||
deserialize_power_levels_content_invite, deserialize_power_levels_content_redact,
|
deserialize_power_levels_content_invite, deserialize_power_levels_content_redact,
|
||||||
},
|
},
|
||||||
room_version::RoomVersion,
|
room_version::RoomVersion,
|
||||||
Error, Event, Result, StateEventType, TimelineEventType,
|
Error, Event, Result, StateEventType, StateKey, TimelineEventType,
|
||||||
};
|
};
|
||||||
|
use crate::{debug, error, trace, warn};
|
||||||
|
|
||||||
// FIXME: field extracting could be bundled for `content`
|
// FIXME: field extracting could be bundled for `content`
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
@ -56,15 +56,15 @@ pub fn auth_types_for_event(
|
||||||
sender: &UserId,
|
sender: &UserId,
|
||||||
state_key: Option<&str>,
|
state_key: Option<&str>,
|
||||||
content: &RawJsonValue,
|
content: &RawJsonValue,
|
||||||
) -> serde_json::Result<Vec<(StateEventType, String)>> {
|
) -> serde_json::Result<Vec<(StateEventType, StateKey)>> {
|
||||||
if kind == &TimelineEventType::RoomCreate {
|
if kind == &TimelineEventType::RoomCreate {
|
||||||
return Ok(vec![]);
|
return Ok(vec![]);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut auth_types = vec![
|
let mut auth_types = vec![
|
||||||
(StateEventType::RoomPowerLevels, String::new()),
|
(StateEventType::RoomPowerLevels, StateKey::new()),
|
||||||
(StateEventType::RoomMember, sender.to_string()),
|
(StateEventType::RoomMember, sender.as_str().into()),
|
||||||
(StateEventType::RoomCreate, String::new()),
|
(StateEventType::RoomCreate, StateKey::new()),
|
||||||
];
|
];
|
||||||
|
|
||||||
if kind == &TimelineEventType::RoomMember {
|
if kind == &TimelineEventType::RoomMember {
|
||||||
|
@ -82,7 +82,7 @@ pub fn auth_types_for_event(
|
||||||
if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock]
|
if [MembershipState::Join, MembershipState::Invite, MembershipState::Knock]
|
||||||
.contains(&membership)
|
.contains(&membership)
|
||||||
{
|
{
|
||||||
let key = (StateEventType::RoomJoinRules, String::new());
|
let key = (StateEventType::RoomJoinRules, StateKey::new());
|
||||||
if !auth_types.contains(&key) {
|
if !auth_types.contains(&key) {
|
||||||
auth_types.push(key);
|
auth_types.push(key);
|
||||||
}
|
}
|
||||||
|
@ -91,21 +91,22 @@ pub fn auth_types_for_event(
|
||||||
.join_authorised_via_users_server
|
.join_authorised_via_users_server
|
||||||
.map(|m| m.deserialize())
|
.map(|m| m.deserialize())
|
||||||
{
|
{
|
||||||
let key = (StateEventType::RoomMember, u.to_string());
|
let key = (StateEventType::RoomMember, u.as_str().into());
|
||||||
if !auth_types.contains(&key) {
|
if !auth_types.contains(&key) {
|
||||||
auth_types.push(key);
|
auth_types.push(key);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let key = (StateEventType::RoomMember, state_key.to_owned());
|
let key = (StateEventType::RoomMember, state_key.into());
|
||||||
if !auth_types.contains(&key) {
|
if !auth_types.contains(&key) {
|
||||||
auth_types.push(key);
|
auth_types.push(key);
|
||||||
}
|
}
|
||||||
|
|
||||||
if membership == MembershipState::Invite {
|
if membership == MembershipState::Invite {
|
||||||
if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) {
|
if let Some(Ok(t_id)) = content.third_party_invite.map(|t| t.deserialize()) {
|
||||||
let key = (StateEventType::RoomThirdPartyInvite, t_id.signed.token);
|
let key =
|
||||||
|
(StateEventType::RoomThirdPartyInvite, t_id.signed.token.into());
|
||||||
if !auth_types.contains(&key) {
|
if !auth_types.contains(&key) {
|
||||||
auth_types.push(key);
|
auth_types.push(key);
|
||||||
}
|
}
|
||||||
|
@ -128,7 +129,13 @@ pub fn auth_types_for_event(
|
||||||
/// The `fetch_state` closure should gather state from a state snapshot. We need
|
/// The `fetch_state` closure should gather state from a state snapshot. We need
|
||||||
/// to know if the event passes auth against some state not a recursive
|
/// to know if the event passes auth against some state not a recursive
|
||||||
/// collection of auth_events fields.
|
/// collection of auth_events fields.
|
||||||
#[instrument(level = "debug", skip_all, fields(event_id = incoming_event.event_id().borrow().as_str()))]
|
#[tracing::instrument(
|
||||||
|
level = "debug",
|
||||||
|
skip_all,
|
||||||
|
fields(
|
||||||
|
event_id = incoming_event.event_id().borrow().as_str()
|
||||||
|
)
|
||||||
|
)]
|
||||||
pub async fn auth_check<F, Fut, Fetched, Incoming>(
|
pub async fn auth_check<F, Fut, Fetched, Incoming>(
|
||||||
room_version: &RoomVersion,
|
room_version: &RoomVersion,
|
||||||
incoming_event: &Incoming,
|
incoming_event: &Incoming,
|
||||||
|
@ -136,10 +143,10 @@ pub async fn auth_check<F, Fut, Fetched, Incoming>(
|
||||||
fetch_state: F,
|
fetch_state: F,
|
||||||
) -> Result<bool, Error>
|
) -> Result<bool, Error>
|
||||||
where
|
where
|
||||||
F: Fn(&'static StateEventType, &str) -> Fut,
|
F: Fn(&StateEventType, &str) -> Fut + Send,
|
||||||
Fut: Future<Output = Option<Fetched>> + Send,
|
Fut: Future<Output = Option<Fetched>> + Send,
|
||||||
Fetched: Event + Send,
|
Fetched: Event + Send,
|
||||||
Incoming: Event + Send,
|
Incoming: Event + Send + Sync,
|
||||||
{
|
{
|
||||||
debug!(
|
debug!(
|
||||||
"auth_check beginning for {} ({})",
|
"auth_check beginning for {} ({})",
|
||||||
|
@ -262,6 +269,7 @@ where
|
||||||
// sender domain of the event does not match the sender domain of the create
|
// sender domain of the event does not match the sender domain of the create
|
||||||
// event, reject.
|
// event, reject.
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
#[allow(clippy::items_after_statements)]
|
||||||
struct RoomCreateContentFederate {
|
struct RoomCreateContentFederate {
|
||||||
#[serde(rename = "m.federate", default = "ruma::serde::default_true")]
|
#[serde(rename = "m.federate", default = "ruma::serde::default_true")]
|
||||||
federate: bool,
|
federate: bool,
|
||||||
|
@ -354,7 +362,7 @@ where
|
||||||
join_rules_event.as_ref(),
|
join_rules_event.as_ref(),
|
||||||
user_for_join_auth.as_deref(),
|
user_for_join_auth.as_deref(),
|
||||||
&user_for_join_auth_membership,
|
&user_for_join_auth_membership,
|
||||||
room_create_event,
|
&room_create_event,
|
||||||
)? {
|
)? {
|
||||||
return Ok(false);
|
return Ok(false);
|
||||||
}
|
}
|
||||||
|
@ -364,6 +372,7 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
// If the sender's current membership state is not join, reject
|
// If the sender's current membership state is not join, reject
|
||||||
|
#[allow(clippy::manual_let_else)]
|
||||||
let sender_member_event = match sender_member_event {
|
let sender_member_event = match sender_member_event {
|
||||||
| Some(mem) => mem,
|
| Some(mem) => mem,
|
||||||
| None => {
|
| None => {
|
||||||
|
@ -498,19 +507,20 @@ where
|
||||||
/// This is generated by calling `auth_types_for_event` with the membership
|
/// This is generated by calling `auth_types_for_event` with the membership
|
||||||
/// event and the current State.
|
/// event and the current State.
|
||||||
#[allow(clippy::too_many_arguments)]
|
#[allow(clippy::too_many_arguments)]
|
||||||
|
#[allow(clippy::cognitive_complexity)]
|
||||||
fn valid_membership_change(
|
fn valid_membership_change(
|
||||||
room_version: &RoomVersion,
|
room_version: &RoomVersion,
|
||||||
target_user: &UserId,
|
target_user: &UserId,
|
||||||
target_user_membership_event: Option<impl Event>,
|
target_user_membership_event: Option<&impl Event>,
|
||||||
sender: &UserId,
|
sender: &UserId,
|
||||||
sender_membership_event: Option<impl Event>,
|
sender_membership_event: Option<&impl Event>,
|
||||||
current_event: impl Event,
|
current_event: impl Event,
|
||||||
current_third_party_invite: Option<impl Event>,
|
current_third_party_invite: Option<&impl Event>,
|
||||||
power_levels_event: Option<impl Event>,
|
power_levels_event: Option<&impl Event>,
|
||||||
join_rules_event: Option<impl Event>,
|
join_rules_event: Option<&impl Event>,
|
||||||
user_for_join_auth: Option<&UserId>,
|
user_for_join_auth: Option<&UserId>,
|
||||||
user_for_join_auth_membership: &MembershipState,
|
user_for_join_auth_membership: &MembershipState,
|
||||||
create_room: impl Event,
|
create_room: &impl Event,
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct GetThirdPartyInvite {
|
struct GetThirdPartyInvite {
|
||||||
|
@ -856,6 +866,7 @@ fn check_power_levels(
|
||||||
// and integers here
|
// and integers here
|
||||||
debug!("validation of power event finished");
|
debug!("validation of power event finished");
|
||||||
|
|
||||||
|
#[allow(clippy::manual_let_else)]
|
||||||
let current_state = match previous_power_event {
|
let current_state = match previous_power_event {
|
||||||
| Some(current_state) => current_state,
|
| Some(current_state) => current_state,
|
||||||
// If there is no previous m.room.power_levels event in the room, allow
|
// If there is no previous m.room.power_levels event in the room, allow
|
||||||
|
@ -1054,6 +1065,7 @@ fn verify_third_party_invite(
|
||||||
|
|
||||||
// If there is no m.room.third_party_invite event in the current room state with
|
// If there is no m.room.third_party_invite event in the current room state with
|
||||||
// state_key matching token, reject
|
// state_key matching token, reject
|
||||||
|
#[allow(clippy::manual_let_else)]
|
||||||
let current_tpid = match current_third_party_invite {
|
let current_tpid = match current_third_party_invite {
|
||||||
| Some(id) => id,
|
| Some(id) => id,
|
||||||
| None => return false,
|
| None => return false,
|
||||||
|
@ -1069,12 +1081,14 @@ fn verify_third_party_invite(
|
||||||
|
|
||||||
// If any signature in signed matches any public key in the
|
// If any signature in signed matches any public key in the
|
||||||
// m.room.third_party_invite event, allow
|
// m.room.third_party_invite event, allow
|
||||||
|
#[allow(clippy::manual_let_else)]
|
||||||
let tpid_ev =
|
let tpid_ev =
|
||||||
match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) {
|
match from_json_str::<RoomThirdPartyInviteEventContent>(current_tpid.content().get()) {
|
||||||
| Ok(ev) => ev,
|
| Ok(ev) => ev,
|
||||||
| Err(_) => return false,
|
| Err(_) => return false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
#[allow(clippy::manual_let_else)]
|
||||||
let decoded_invite_token = match Base64::parse(&tp_id.signed.token) {
|
let decoded_invite_token = match Base64::parse(&tp_id.signed.token) {
|
||||||
| Ok(tok) => tok,
|
| Ok(tok) => tok,
|
||||||
// FIXME: Log a warning?
|
// FIXME: Log a warning?
|
||||||
|
@ -1096,7 +1110,7 @@ fn verify_third_party_invite(
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use ruma_events::{
|
use ruma::events::{
|
||||||
room::{
|
room::{
|
||||||
join_rules::{
|
join_rules::{
|
||||||
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
|
AllowRule, JoinRule, Restricted, RoomJoinRulesEventContent, RoomMembership,
|
||||||
|
@ -1107,7 +1121,7 @@ mod tests {
|
||||||
};
|
};
|
||||||
use serde_json::value::to_raw_value as to_raw_json_value;
|
use serde_json::value::to_raw_value as to_raw_json_value;
|
||||||
|
|
||||||
use crate::{
|
use crate::state_res::{
|
||||||
event_auth::valid_membership_change,
|
event_auth::valid_membership_change,
|
||||||
test_utils::{
|
test_utils::{
|
||||||
alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id,
|
alice, charlie, ella, event_id, member_content_ban, member_content_join, room_id,
|
||||||
|
@ -1145,16 +1159,16 @@ mod tests {
|
||||||
assert!(valid_membership_change(
|
assert!(valid_membership_change(
|
||||||
&RoomVersion::V6,
|
&RoomVersion::V6,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
None,
|
None,
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
@ -1188,16 +1202,16 @@ mod tests {
|
||||||
assert!(!valid_membership_change(
|
assert!(!valid_membership_change(
|
||||||
&RoomVersion::V6,
|
&RoomVersion::V6,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
None,
|
None,
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
@ -1231,16 +1245,16 @@ mod tests {
|
||||||
assert!(valid_membership_change(
|
assert!(valid_membership_change(
|
||||||
&RoomVersion::V6,
|
&RoomVersion::V6,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
None,
|
None,
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
@ -1274,16 +1288,16 @@ mod tests {
|
||||||
assert!(!valid_membership_change(
|
assert!(!valid_membership_change(
|
||||||
&RoomVersion::V6,
|
&RoomVersion::V6,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
None,
|
None,
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
@ -1334,32 +1348,32 @@ mod tests {
|
||||||
assert!(valid_membership_change(
|
assert!(valid_membership_change(
|
||||||
&RoomVersion::V9,
|
&RoomVersion::V9,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
Some(alice()),
|
Some(alice()),
|
||||||
&MembershipState::Join,
|
&MembershipState::Join,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
|
|
||||||
assert!(!valid_membership_change(
|
assert!(!valid_membership_change(
|
||||||
&RoomVersion::V9,
|
&RoomVersion::V9,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
Some(ella()),
|
Some(ella()),
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
@ -1402,16 +1416,16 @@ mod tests {
|
||||||
assert!(valid_membership_change(
|
assert!(valid_membership_change(
|
||||||
&RoomVersion::V7,
|
&RoomVersion::V7,
|
||||||
target_user,
|
target_user,
|
||||||
fetch_state(StateEventType::RoomMember, target_user.to_string()),
|
fetch_state(StateEventType::RoomMember, target_user.as_str().into()).as_ref(),
|
||||||
sender,
|
sender,
|
||||||
fetch_state(StateEventType::RoomMember, sender.to_string()),
|
fetch_state(StateEventType::RoomMember, sender.as_str().into()).as_ref(),
|
||||||
&requester,
|
&requester,
|
||||||
None::<PduEvent>,
|
None::<&PduEvent>,
|
||||||
fetch_state(StateEventType::RoomPowerLevels, "".to_owned()),
|
fetch_state(StateEventType::RoomPowerLevels, "".into()).as_ref(),
|
||||||
fetch_state(StateEventType::RoomJoinRules, "".to_owned()),
|
fetch_state(StateEventType::RoomJoinRules, "".into()).as_ref(),
|
||||||
None,
|
None,
|
||||||
&MembershipState::Leave,
|
&MembershipState::Leave,
|
||||||
fetch_state(StateEventType::RoomCreate, "".to_owned()).unwrap(),
|
&fetch_state(StateEventType::RoomCreate, "".into()).unwrap(),
|
||||||
)
|
)
|
||||||
.unwrap());
|
.unwrap());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,3 +1,5 @@
|
||||||
|
#![cfg_attr(test, allow(warnings))]
|
||||||
|
|
||||||
pub(crate) mod error;
|
pub(crate) mod error;
|
||||||
pub mod event_auth;
|
pub mod event_auth;
|
||||||
mod power_levels;
|
mod power_levels;
|
||||||
|
@ -12,7 +14,7 @@ use std::{
|
||||||
cmp::{Ordering, Reverse},
|
cmp::{Ordering, Reverse},
|
||||||
collections::{BinaryHeap, HashMap, HashSet},
|
collections::{BinaryHeap, HashMap, HashSet},
|
||||||
fmt::Debug,
|
fmt::Debug,
|
||||||
hash::Hash,
|
hash::{BuildHasher, Hash},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
|
||||||
|
@ -32,13 +34,13 @@ pub use self::{
|
||||||
room_version::RoomVersion,
|
room_version::RoomVersion,
|
||||||
state_event::Event,
|
state_event::Event,
|
||||||
};
|
};
|
||||||
use crate::{debug, trace, warn};
|
use crate::{debug, pdu::StateKey, trace, warn};
|
||||||
|
|
||||||
/// A mapping of event type and state_key to some value `T`, usually an
|
/// A mapping of event type and state_key to some value `T`, usually an
|
||||||
/// `EventId`.
|
/// `EventId`.
|
||||||
pub type StateMap<T> = HashMap<TypeStateKey, T>;
|
pub type StateMap<T> = HashMap<TypeStateKey, T>;
|
||||||
pub type StateMapItem<T> = (TypeStateKey, T);
|
pub type StateMapItem<T> = (TypeStateKey, T);
|
||||||
pub type TypeStateKey = (StateEventType, String);
|
pub type TypeStateKey = (StateEventType, StateKey);
|
||||||
|
|
||||||
type Result<T, E = Error> = crate::Result<T, E>;
|
type Result<T, E = Error> = crate::Result<T, E>;
|
||||||
|
|
||||||
|
@ -68,10 +70,10 @@ type Result<T, E = Error> = crate::Result<T, E>;
|
||||||
/// event is part of the same room.
|
/// event is part of the same room.
|
||||||
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
|
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
|
||||||
//#[tracing::instrument(level event_fetch))]
|
//#[tracing::instrument(level event_fetch))]
|
||||||
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
|
pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
|
||||||
room_version: &RoomVersionId,
|
room_version: &RoomVersionId,
|
||||||
state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
|
state_sets: Sets,
|
||||||
auth_chain_sets: &'a [HashSet<E::Id>],
|
auth_chain_sets: &'a [HashSet<E::Id, Hasher>],
|
||||||
event_fetch: &Fetch,
|
event_fetch: &Fetch,
|
||||||
event_exists: &Exists,
|
event_exists: &Exists,
|
||||||
parallel_fetches: usize,
|
parallel_fetches: usize,
|
||||||
|
@ -81,7 +83,9 @@ where
|
||||||
FetchFut: Future<Output = Option<E>> + Send,
|
FetchFut: Future<Output = Option<E>> + Send,
|
||||||
Exists: Fn(E::Id) -> ExistsFut + Sync,
|
Exists: Fn(E::Id) -> ExistsFut + Sync,
|
||||||
ExistsFut: Future<Output = bool> + Send,
|
ExistsFut: Future<Output = bool> + Send,
|
||||||
|
Sets: IntoIterator<IntoIter = SetIter> + Send,
|
||||||
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
|
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
|
||||||
|
Hasher: BuildHasher + Send + Sync,
|
||||||
E: Event + Clone + Send + Sync,
|
E: Event + Clone + Send + Sync,
|
||||||
E::Id: Borrow<EventId> + Send + Sync,
|
E::Id: Borrow<EventId> + Send + Sync,
|
||||||
for<'b> &'b E: Send,
|
for<'b> &'b E: Send,
|
||||||
|
@ -178,7 +182,7 @@ where
|
||||||
trace!(list = ?events_to_resolve, "events left to resolve");
|
trace!(list = ?events_to_resolve, "events left to resolve");
|
||||||
|
|
||||||
// This "epochs" power level event
|
// This "epochs" power level event
|
||||||
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new()));
|
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new()));
|
||||||
|
|
||||||
debug!(event_id = ?power_event, "power event");
|
debug!(event_id = ?power_event, "power event");
|
||||||
|
|
||||||
|
@ -222,16 +226,17 @@ fn separate<'a, Id>(
|
||||||
where
|
where
|
||||||
Id: Clone + Eq + Hash + 'a,
|
Id: Clone + Eq + Hash + 'a,
|
||||||
{
|
{
|
||||||
let mut state_set_count = 0_usize;
|
let mut state_set_count: usize = 0;
|
||||||
let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
|
let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
|
||||||
|
|
||||||
let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
|
let state_sets_iter =
|
||||||
|
state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1));
|
||||||
for (k, v) in state_sets_iter.flatten() {
|
for (k, v) in state_sets_iter.flatten() {
|
||||||
occurrences
|
occurrences
|
||||||
.entry(k)
|
.entry(k)
|
||||||
.or_default()
|
.or_default()
|
||||||
.entry(v)
|
.entry(v)
|
||||||
.and_modify(|x| *x += 1)
|
.and_modify(|x: &mut usize| *x = x.saturating_add(1))
|
||||||
.or_insert(1);
|
.or_insert(1);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -246,7 +251,7 @@ where
|
||||||
conflicted_state
|
conflicted_state
|
||||||
.entry((k.0.clone(), k.1.clone()))
|
.entry((k.0.clone(), k.1.clone()))
|
||||||
.and_modify(|x: &mut Vec<_>| x.push(id.clone()))
|
.and_modify(|x: &mut Vec<_>| x.push(id.clone()))
|
||||||
.or_insert(vec![id.clone()]);
|
.or_insert_with(|| vec![id.clone()]);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -255,9 +260,13 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Vec of deduped EventIds that appear in some chains but not others.
|
/// Returns a Vec of deduped EventIds that appear in some chains but not others.
|
||||||
fn get_auth_chain_diff<Id>(auth_chain_sets: &[HashSet<Id>]) -> impl Iterator<Item = Id> + Send
|
#[allow(clippy::arithmetic_side_effects)]
|
||||||
|
fn get_auth_chain_diff<Id, Hasher>(
|
||||||
|
auth_chain_sets: &[HashSet<Id, Hasher>],
|
||||||
|
) -> impl Iterator<Item = Id> + Send
|
||||||
where
|
where
|
||||||
Id: Clone + Eq + Hash + Send,
|
Id: Clone + Eq + Hash + Send,
|
||||||
|
Hasher: BuildHasher + Send + Sync,
|
||||||
{
|
{
|
||||||
let num_sets = auth_chain_sets.len();
|
let num_sets = auth_chain_sets.len();
|
||||||
let mut id_counts: HashMap<Id, usize> = HashMap::new();
|
let mut id_counts: HashMap<Id, usize> = HashMap::new();
|
||||||
|
@ -288,7 +297,7 @@ async fn reverse_topological_power_sort<E, F, Fut>(
|
||||||
where
|
where
|
||||||
F: Fn(E::Id) -> Fut + Sync,
|
F: Fn(E::Id) -> Fut + Sync,
|
||||||
Fut: Future<Output = Option<E>> + Send,
|
Fut: Future<Output = Option<E>> + Send,
|
||||||
E: Event + Send,
|
E: Event + Send + Sync,
|
||||||
E::Id: Borrow<EventId> + Send + Sync,
|
E::Id: Borrow<EventId> + Send + Sync,
|
||||||
{
|
{
|
||||||
debug!("reverse topological sort of power events");
|
debug!("reverse topological sort of power events");
|
||||||
|
@ -337,14 +346,15 @@ where
|
||||||
/// `key_fn` is used as to obtain the power level and age of an event for
|
/// `key_fn` is used as to obtain the power level and age of an event for
|
||||||
/// breaking ties (together with the event ID).
|
/// breaking ties (together with the event ID).
|
||||||
#[tracing::instrument(level = "debug", skip_all)]
|
#[tracing::instrument(level = "debug", skip_all)]
|
||||||
pub async fn lexicographical_topological_sort<Id, F, Fut>(
|
pub async fn lexicographical_topological_sort<Id, F, Fut, Hasher>(
|
||||||
graph: &HashMap<Id, HashSet<Id>>,
|
graph: &HashMap<Id, HashSet<Id, Hasher>>,
|
||||||
key_fn: &F,
|
key_fn: &F,
|
||||||
) -> Result<Vec<Id>>
|
) -> Result<Vec<Id>>
|
||||||
where
|
where
|
||||||
F: Fn(Id) -> Fut + Sync,
|
F: Fn(Id) -> Fut + Sync,
|
||||||
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
|
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
|
||||||
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
|
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send + Sync,
|
||||||
|
Hasher: BuildHasher + Default + Clone + Send + Sync,
|
||||||
{
|
{
|
||||||
#[derive(PartialEq, Eq)]
|
#[derive(PartialEq, Eq)]
|
||||||
struct TieBreaker<'a, Id> {
|
struct TieBreaker<'a, Id> {
|
||||||
|
@ -395,7 +405,7 @@ where
|
||||||
|
|
||||||
// The number of events that depend on the given event (the EventId key)
|
// The number of events that depend on the given event (the EventId key)
|
||||||
// How many events reference this event in the DAG as a parent
|
// How many events reference this event in the DAG as a parent
|
||||||
let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
|
let mut reverse_graph: HashMap<_, HashSet<_, Hasher>> = HashMap::new();
|
||||||
|
|
||||||
// Vec of nodes that have zero out degree, least recent events.
|
// Vec of nodes that have zero out degree, least recent events.
|
||||||
let mut zero_outdegree = Vec::new();
|
let mut zero_outdegree = Vec::new();
|
||||||
|
@ -727,8 +737,8 @@ async fn get_mainline_depth<E, F, Fut>(
|
||||||
where
|
where
|
||||||
F: Fn(E::Id) -> Fut + Sync,
|
F: Fn(E::Id) -> Fut + Sync,
|
||||||
Fut: Future<Output = Option<E>> + Send,
|
Fut: Future<Output = Option<E>> + Send,
|
||||||
E: Event + Send,
|
E: Event + Send + Sync,
|
||||||
E::Id: Borrow<EventId> + Send,
|
E::Id: Borrow<EventId> + Send + Sync,
|
||||||
{
|
{
|
||||||
while let Some(sort_ev) = event {
|
while let Some(sort_ev) = event {
|
||||||
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
|
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
|
||||||
|
@ -758,10 +768,10 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
|
||||||
auth_diff: &HashSet<E::Id>,
|
auth_diff: &HashSet<E::Id>,
|
||||||
fetch_event: &F,
|
fetch_event: &F,
|
||||||
) where
|
) where
|
||||||
F: Fn(E::Id) -> Fut,
|
F: Fn(E::Id) -> Fut + Sync,
|
||||||
Fut: Future<Output = Option<E>> + Send,
|
Fut: Future<Output = Option<E>> + Send,
|
||||||
E: Event + Send,
|
E: Event + Send + Sync,
|
||||||
E::Id: Borrow<EventId> + Clone + Send,
|
E::Id: Borrow<EventId> + Clone + Send + Sync,
|
||||||
{
|
{
|
||||||
let mut state = vec![event_id];
|
let mut state = vec![event_id];
|
||||||
while let Some(eid) = state.pop() {
|
while let Some(eid) = state.pop() {
|
||||||
|
@ -788,7 +798,7 @@ where
|
||||||
F: Fn(E::Id) -> Fut + Sync,
|
F: Fn(E::Id) -> Fut + Sync,
|
||||||
Fut: Future<Output = Option<E>> + Send,
|
Fut: Future<Output = Option<E>> + Send,
|
||||||
E: Event + Send,
|
E: Event + Send,
|
||||||
E::Id: Borrow<EventId> + Send,
|
E::Id: Borrow<EventId> + Send + Sync,
|
||||||
{
|
{
|
||||||
match fetch(event_id.clone()).await.as_ref() {
|
match fetch(event_id.clone()).await.as_ref() {
|
||||||
| Some(state) => is_power_event(state),
|
| Some(state) => is_power_event(state),
|
||||||
|
@ -820,18 +830,18 @@ fn is_power_event(event: impl Event) -> bool {
|
||||||
|
|
||||||
/// Convenience trait for adding event type plus state key to state maps.
|
/// Convenience trait for adding event type plus state key to state maps.
|
||||||
pub trait EventTypeExt {
|
pub trait EventTypeExt {
|
||||||
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
|
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey);
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EventTypeExt for StateEventType {
|
impl EventTypeExt for StateEventType {
|
||||||
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
|
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
|
||||||
(self, state_key.into())
|
(self, state_key.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
impl EventTypeExt for TimelineEventType {
|
impl EventTypeExt for TimelineEventType {
|
||||||
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
|
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
|
||||||
(self.to_string().into(), state_key.into())
|
(self.into(), state_key.into())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -839,7 +849,7 @@ impl<T> EventTypeExt for &T
|
||||||
where
|
where
|
||||||
T: EventTypeExt + Clone,
|
T: EventTypeExt + Clone,
|
||||||
{
|
{
|
||||||
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
|
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
|
||||||
self.to_owned().with_state_key(state_key)
|
self.to_owned().with_state_key(state_key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -858,13 +868,11 @@ mod tests {
|
||||||
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||||
StateEventType, TimelineEventType,
|
StateEventType, TimelineEventType,
|
||||||
},
|
},
|
||||||
int, uint,
|
int, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
|
||||||
};
|
};
|
||||||
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId};
|
|
||||||
use serde_json::{json, value::to_raw_value as to_raw_json_value};
|
use serde_json::{json, value::to_raw_value as to_raw_json_value};
|
||||||
use tracing::debug;
|
|
||||||
|
|
||||||
use crate::{
|
use super::{
|
||||||
is_power_event,
|
is_power_event,
|
||||||
room_version::RoomVersion,
|
room_version::RoomVersion,
|
||||||
test_utils::{
|
test_utils::{
|
||||||
|
@ -874,6 +882,7 @@ mod tests {
|
||||||
},
|
},
|
||||||
Event, EventTypeExt, StateMap,
|
Event, EventTypeExt, StateMap,
|
||||||
};
|
};
|
||||||
|
use crate::debug;
|
||||||
|
|
||||||
async fn test_event_sort() {
|
async fn test_event_sort() {
|
||||||
use futures::future::ready;
|
use futures::future::ready;
|
||||||
|
@ -898,11 +907,11 @@ mod tests {
|
||||||
|
|
||||||
let fetcher = |id| ready(events.get(&id).cloned());
|
let fetcher = |id| ready(events.get(&id).cloned());
|
||||||
let sorted_power_events =
|
let sorted_power_events =
|
||||||
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
|
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
let resolved_power = crate::iterative_auth_check(
|
let resolved_power = super::iterative_auth_check(
|
||||||
&RoomVersion::V6,
|
&RoomVersion::V6,
|
||||||
sorted_power_events.iter(),
|
sorted_power_events.iter(),
|
||||||
HashMap::new(), // unconflicted events
|
HashMap::new(), // unconflicted events
|
||||||
|
@ -918,10 +927,10 @@ mod tests {
|
||||||
events_to_sort.shuffle(&mut rand::thread_rng());
|
events_to_sort.shuffle(&mut rand::thread_rng());
|
||||||
|
|
||||||
let power_level = resolved_power
|
let power_level = resolved_power
|
||||||
.get(&(StateEventType::RoomPowerLevels, "".to_owned()))
|
.get(&(StateEventType::RoomPowerLevels, "".into()))
|
||||||
.cloned();
|
.cloned();
|
||||||
|
|
||||||
let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
|
let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
|
||||||
.await
|
.await
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
|
||||||
|
@ -1302,7 +1311,7 @@ mod tests {
|
||||||
})
|
})
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
let resolved = match crate::resolve(
|
let resolved = match super::resolve(
|
||||||
&RoomVersionId::V2,
|
&RoomVersionId::V2,
|
||||||
&state_sets,
|
&state_sets,
|
||||||
&auth_chain,
|
&auth_chain,
|
||||||
|
@ -1333,7 +1342,7 @@ mod tests {
|
||||||
event_id("p") => hashset![event_id("o")],
|
event_id("p") => hashset![event_id("o")],
|
||||||
};
|
};
|
||||||
|
|
||||||
let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
|
let res = super::lexicographical_topological_sort(&graph, &|_id| async {
|
||||||
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
@ -1421,7 +1430,7 @@ mod tests {
|
||||||
|
|
||||||
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
|
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
|
||||||
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
|
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
|
||||||
let resolved = match crate::resolve(
|
let resolved = match super::resolve(
|
||||||
&RoomVersionId::V6,
|
&RoomVersionId::V6,
|
||||||
&state_sets,
|
&state_sets,
|
||||||
&auth_chain,
|
&auth_chain,
|
||||||
|
@ -1552,7 +1561,7 @@ mod tests {
|
||||||
#[allow(unused_mut)]
|
#[allow(unused_mut)]
|
||||||
let mut x = StateMap::new();
|
let mut x = StateMap::new();
|
||||||
$(
|
$(
|
||||||
x.insert(($kind, $key.to_owned()), $id);
|
x.insert(($kind, $key.into()), $id);
|
||||||
)*
|
)*
|
||||||
x
|
x
|
||||||
}};
|
}};
|
||||||
|
|
|
@ -32,6 +32,7 @@ pub enum StateResolutionVersion {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
|
#[cfg_attr(not(feature = "unstable-exhaustive-types"), non_exhaustive)]
|
||||||
|
#[allow(clippy::struct_excessive_bools)]
|
||||||
pub struct RoomVersion {
|
pub struct RoomVersion {
|
||||||
/// The stability of this room.
|
/// The stability of this room.
|
||||||
pub disposition: RoomDisposition,
|
pub disposition: RoomDisposition,
|
||||||
|
|
|
@ -7,28 +7,28 @@ use std::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use futures_util::future::ready;
|
use futures::future::ready;
|
||||||
use js_int::{int, uint};
|
use ruma::{
|
||||||
use ruma_common::{
|
event_id,
|
||||||
event_id, room_id, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
|
events::{
|
||||||
RoomVersionId, ServerSignatures, UserId,
|
|
||||||
};
|
|
||||||
use ruma_events::{
|
|
||||||
pdu::{EventHash, Pdu, RoomV3Pdu},
|
pdu::{EventHash, Pdu, RoomV3Pdu},
|
||||||
room::{
|
room::{
|
||||||
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
join_rules::{JoinRule, RoomJoinRulesEventContent},
|
||||||
member::{MembershipState, RoomMemberEventContent},
|
member::{MembershipState, RoomMemberEventContent},
|
||||||
},
|
},
|
||||||
TimelineEventType,
|
TimelineEventType,
|
||||||
|
},
|
||||||
|
int, room_id, uint, user_id, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId,
|
||||||
|
RoomVersionId, ServerSignatures, UserId,
|
||||||
};
|
};
|
||||||
use serde_json::{
|
use serde_json::{
|
||||||
json,
|
json,
|
||||||
value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue},
|
value::{to_raw_value as to_raw_json_value, RawValue as RawJsonValue},
|
||||||
};
|
};
|
||||||
use tracing::info;
|
|
||||||
|
|
||||||
pub(crate) use self::event::PduEvent;
|
pub(crate) use self::event::PduEvent;
|
||||||
use crate::{auth_types_for_event, Error, Event, EventTypeExt, Result, StateMap};
|
use super::auth_types_for_event;
|
||||||
|
use crate::{info, Event, EventTypeExt, Result, StateMap};
|
||||||
|
|
||||||
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
|
static SERVER_TIMESTAMP: AtomicU64 = AtomicU64::new(0);
|
||||||
|
|
||||||
|
@ -88,7 +88,7 @@ pub(crate) async fn do_check(
|
||||||
|
|
||||||
// Resolve the current state and add it to the state_at_event map then continue
|
// Resolve the current state and add it to the state_at_event map then continue
|
||||||
// on in "time"
|
// on in "time"
|
||||||
for node in crate::lexicographical_topological_sort(&graph, &|_id| async {
|
for node in super::lexicographical_topological_sort(&graph, &|_id| async {
|
||||||
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
|
||||||
})
|
})
|
||||||
.await
|
.await
|
||||||
|
@ -135,7 +135,7 @@ pub(crate) async fn do_check(
|
||||||
let event_map = &event_map;
|
let event_map = &event_map;
|
||||||
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
|
let fetch = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).cloned());
|
||||||
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
|
let exists = |id: <PduEvent as Event>::Id| ready(event_map.get(&id).is_some());
|
||||||
let resolved = crate::resolve(
|
let resolved = super::resolve(
|
||||||
&RoomVersionId::V6,
|
&RoomVersionId::V6,
|
||||||
state_sets,
|
state_sets,
|
||||||
&auth_chain_sets,
|
&auth_chain_sets,
|
||||||
|
@ -223,7 +223,7 @@ pub(crate) async fn do_check(
|
||||||
// Filter out the dummy messages events.
|
// Filter out the dummy messages events.
|
||||||
// These act as points in time where there should be a known state to
|
// These act as points in time where there should be a known state to
|
||||||
// test against.
|
// test against.
|
||||||
&& **k != ("m.room.message".into(), "dummy".to_owned())
|
&& **k != ("m.room.message".into(), "dummy".into())
|
||||||
})
|
})
|
||||||
.map(|(k, v)| (k.clone(), v.clone()))
|
.map(|(k, v)| (k.clone(), v.clone()))
|
||||||
.collect::<StateMap<OwnedEventId>>();
|
.collect::<StateMap<OwnedEventId>>();
|
||||||
|
@ -239,7 +239,8 @@ impl<E: Event> TestStore<E> {
|
||||||
self.0
|
self.0
|
||||||
.get(event_id)
|
.get(event_id)
|
||||||
.cloned()
|
.cloned()
|
||||||
.ok_or_else(|| Error::NotFound(format!("{event_id} not found")))
|
.ok_or_else(|| super::Error::NotFound(format!("{event_id} not found")))
|
||||||
|
.map_err(Into::into)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a Vec of the related auth events to the given `event`.
|
/// Returns a Vec of the related auth events to the given `event`.
|
||||||
|
@ -582,8 +583,10 @@ pub(crate) fn INITIAL_EDGES() -> Vec<OwnedEventId> {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) mod event {
|
pub(crate) mod event {
|
||||||
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId};
|
use ruma::{
|
||||||
use ruma_events::{pdu::Pdu, TimelineEventType};
|
events::{pdu::Pdu, TimelineEventType},
|
||||||
|
MilliSecondsSinceUnixEpoch, OwnedEventId, RoomId, UserId,
|
||||||
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::value::RawValue as RawJsonValue;
|
use serde_json::value::RawValue as RawJsonValue;
|
||||||
|
|
||||||
|
|
|
@ -133,7 +133,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let state_fetch = |ty: &'static StateEventType, sk: &str| {
|
let state_fetch = |ty: &StateEventType, sk: &str| {
|
||||||
let key = (ty.to_owned(), sk.into());
|
let key = (ty.to_owned(), sk.into());
|
||||||
ready(auth_events.get(&key))
|
ready(auth_events.get(&key))
|
||||||
};
|
};
|
||||||
|
|
|
@ -63,7 +63,6 @@ pub async fn resolve_state(
|
||||||
.multi_get_statekey_from_short(shortstatekeys)
|
.multi_get_statekey_from_short(shortstatekeys)
|
||||||
.zip(event_ids)
|
.zip(event_ids)
|
||||||
.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
|
.ready_filter_map(|(ty_sk, id)| Some((ty_sk.ok()?, id)))
|
||||||
.map(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
|
|
||||||
.collect()
|
.collect()
|
||||||
})
|
})
|
||||||
.map(Ok::<_, Error>)
|
.map(Ok::<_, Error>)
|
||||||
|
|
|
@ -172,7 +172,6 @@ async fn state_at_incoming_fork(
|
||||||
.short
|
.short
|
||||||
.get_statekey_from_short(*k)
|
.get_statekey_from_short(*k)
|
||||||
.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
|
.map_ok(|(ty, sk)| ((ty, sk), id.clone()))
|
||||||
.map_ok(|((ty, sk), id)| ((ty, sk.as_str().to_owned()), id))
|
|
||||||
})
|
})
|
||||||
.ready_filter_map(Result::ok)
|
.ready_filter_map(Result::ok)
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -3,7 +3,7 @@ use std::{borrow::Borrow, collections::BTreeMap, iter::once, sync::Arc, time::In
|
||||||
use conduwuit::{
|
use conduwuit::{
|
||||||
debug, debug_info, err, implement, state_res, trace,
|
debug, debug_info, err, implement, state_res, trace,
|
||||||
utils::stream::{BroadbandExt, ReadyExt},
|
utils::stream::{BroadbandExt, ReadyExt},
|
||||||
warn, Err, EventTypeExt, PduEvent, Result,
|
warn, Err, EventTypeExt, PduEvent, Result, StateKey,
|
||||||
};
|
};
|
||||||
use futures::{future::ready, FutureExt, StreamExt};
|
use futures::{future::ready, FutureExt, StreamExt};
|
||||||
use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName};
|
use ruma::{events::StateEventType, CanonicalJsonValue, RoomId, ServerName};
|
||||||
|
@ -71,8 +71,8 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||||
debug!("Performing auth check");
|
debug!("Performing auth check");
|
||||||
// 11. Check the auth of the event passes based on the state of the event
|
// 11. Check the auth of the event passes based on the state of the event
|
||||||
let state_fetch_state = &state_at_incoming_event;
|
let state_fetch_state = &state_at_incoming_event;
|
||||||
let state_fetch = |k: &'static StateEventType, s: String| async move {
|
let state_fetch = |k: StateEventType, s: StateKey| async move {
|
||||||
let shortstatekey = self.services.short.get_shortstatekey(k, &s).await.ok()?;
|
let shortstatekey = self.services.short.get_shortstatekey(&k, &s).await.ok()?;
|
||||||
|
|
||||||
let event_id = state_fetch_state.get(&shortstatekey)?;
|
let event_id = state_fetch_state.get(&shortstatekey)?;
|
||||||
self.services.timeline.get_pdu(event_id).await.ok()
|
self.services.timeline.get_pdu(event_id).await.ok()
|
||||||
|
@ -82,7 +82,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||||
&room_version,
|
&room_version,
|
||||||
&incoming_pdu,
|
&incoming_pdu,
|
||||||
None, // TODO: third party invite
|
None, // TODO: third party invite
|
||||||
|k, s| state_fetch(k, s.to_owned()),
|
|ty, sk| state_fetch(ty.clone(), sk.into()),
|
||||||
)
|
)
|
||||||
.await
|
.await
|
||||||
.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
|
.map_err(|e| err!(Request(Forbidden("Auth check failed: {e:?}"))))?;
|
||||||
|
@ -104,7 +104,7 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let state_fetch = |k: &'static StateEventType, s: &str| {
|
let state_fetch = |k: &StateEventType, s: &str| {
|
||||||
let key = k.with_state_key(s);
|
let key = k.with_state_key(s);
|
||||||
ready(auth_events.get(&key).cloned())
|
ready(auth_events.get(&key).cloned())
|
||||||
};
|
};
|
||||||
|
|
|
@ -747,7 +747,7 @@ impl Service {
|
||||||
};
|
};
|
||||||
|
|
||||||
let auth_fetch = |k: &StateEventType, s: &str| {
|
let auth_fetch = |k: &StateEventType, s: &str| {
|
||||||
let key = (k.clone(), s.to_owned());
|
let key = (k.clone(), s.into());
|
||||||
ready(auth_events.get(&key))
|
ready(auth_events.get(&key))
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue