optimize further into state-res with SmallString

triage and de-lints for state-res.

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-02-08 01:58:13 +00:00 committed by strawberry
commit f2ca670c3b
15 changed files with 192 additions and 145 deletions
src/core/state_res

View file

@ -1,3 +1,5 @@
#![cfg_attr(test, allow(warnings))]
pub(crate) mod error;
pub mod event_auth;
mod power_levels;
@ -12,7 +14,7 @@ use std::{
cmp::{Ordering, Reverse},
collections::{BinaryHeap, HashMap, HashSet},
fmt::Debug,
hash::Hash,
hash::{BuildHasher, Hash},
};
use futures::{future, stream, Future, FutureExt, StreamExt, TryFutureExt, TryStreamExt};
@ -32,13 +34,13 @@ pub use self::{
room_version::RoomVersion,
state_event::Event,
};
use crate::{debug, trace, warn};
use crate::{debug, pdu::StateKey, trace, warn};
/// A mapping of event type and state_key to some value `T`, usually an
/// `EventId`.
pub type StateMap<T> = HashMap<TypeStateKey, T>;
pub type StateMapItem<T> = (TypeStateKey, T);
pub type TypeStateKey = (StateEventType, String);
pub type TypeStateKey = (StateEventType, StateKey);
type Result<T, E = Error> = crate::Result<T, E>;
@ -68,10 +70,10 @@ type Result<T, E = Error> = crate::Result<T, E>;
/// event is part of the same room.
//#[tracing::instrument(level = "debug", skip(state_sets, auth_chain_sets,
// event_fetch))]
pub async fn resolve<'a, E, SetIter, Fetch, FetchFut, Exists, ExistsFut>(
pub async fn resolve<'a, E, Sets, SetIter, Hasher, Fetch, FetchFut, Exists, ExistsFut>(
room_version: &RoomVersionId,
state_sets: impl IntoIterator<IntoIter = SetIter> + Send,
auth_chain_sets: &'a [HashSet<E::Id>],
state_sets: Sets,
auth_chain_sets: &'a [HashSet<E::Id, Hasher>],
event_fetch: &Fetch,
event_exists: &Exists,
parallel_fetches: usize,
@ -81,7 +83,9 @@ where
FetchFut: Future<Output = Option<E>> + Send,
Exists: Fn(E::Id) -> ExistsFut + Sync,
ExistsFut: Future<Output = bool> + Send,
Sets: IntoIterator<IntoIter = SetIter> + Send,
SetIter: Iterator<Item = &'a StateMap<E::Id>> + Clone + Send,
Hasher: BuildHasher + Send + Sync,
E: Event + Clone + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
for<'b> &'b E: Send,
@ -178,7 +182,7 @@ where
trace!(list = ?events_to_resolve, "events left to resolve");
// This "epochs" power level event
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, String::new()));
let power_event = resolved_control.get(&(StateEventType::RoomPowerLevels, StateKey::new()));
debug!(event_id = ?power_event, "power event");
@ -222,16 +226,17 @@ fn separate<'a, Id>(
where
Id: Clone + Eq + Hash + 'a,
{
let mut state_set_count = 0_usize;
let mut state_set_count: usize = 0;
let mut occurrences = HashMap::<_, HashMap<_, _>>::new();
let state_sets_iter = state_sets_iter.inspect(|_| state_set_count += 1);
let state_sets_iter =
state_sets_iter.inspect(|_| state_set_count = state_set_count.saturating_add(1));
for (k, v) in state_sets_iter.flatten() {
occurrences
.entry(k)
.or_default()
.entry(v)
.and_modify(|x| *x += 1)
.and_modify(|x: &mut usize| *x = x.saturating_add(1))
.or_insert(1);
}
@ -246,7 +251,7 @@ where
conflicted_state
.entry((k.0.clone(), k.1.clone()))
.and_modify(|x: &mut Vec<_>| x.push(id.clone()))
.or_insert(vec![id.clone()]);
.or_insert_with(|| vec![id.clone()]);
}
}
}
@ -255,9 +260,13 @@ where
}
/// Returns an iterator over the deduped EventIds that appear in some chains but
/// not others.
fn get_auth_chain_diff<Id>(auth_chain_sets: &[HashSet<Id>]) -> impl Iterator<Item = Id> + Send
#[allow(clippy::arithmetic_side_effects)]
fn get_auth_chain_diff<Id, Hasher>(
auth_chain_sets: &[HashSet<Id, Hasher>],
) -> impl Iterator<Item = Id> + Send
where
Id: Clone + Eq + Hash + Send,
Hasher: BuildHasher + Send + Sync,
{
let num_sets = auth_chain_sets.len();
let mut id_counts: HashMap<Id, usize> = HashMap::new();
@ -288,7 +297,7 @@ async fn reverse_topological_power_sort<E, F, Fut>(
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
{
debug!("reverse topological sort of power events");
@ -337,14 +346,15 @@ where
/// `key_fn` is used to obtain the power level and age of an event for
/// breaking ties (together with the event ID).
#[tracing::instrument(level = "debug", skip_all)]
pub async fn lexicographical_topological_sort<Id, F, Fut>(
graph: &HashMap<Id, HashSet<Id>>,
pub async fn lexicographical_topological_sort<Id, F, Fut, Hasher>(
graph: &HashMap<Id, HashSet<Id, Hasher>>,
key_fn: &F,
) -> Result<Vec<Id>>
where
F: Fn(Id) -> Fut + Sync,
Fut: Future<Output = Result<(Int, MilliSecondsSinceUnixEpoch)>> + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send,
Id: Borrow<EventId> + Clone + Eq + Hash + Ord + Send + Sync,
Hasher: BuildHasher + Default + Clone + Send + Sync,
{
#[derive(PartialEq, Eq)]
struct TieBreaker<'a, Id> {
@ -395,7 +405,7 @@ where
// The number of events that depend on the given event (the EventId key)
// How many events reference this event in the DAG as a parent
let mut reverse_graph: HashMap<_, HashSet<_>> = HashMap::new();
let mut reverse_graph: HashMap<_, HashSet<_, Hasher>> = HashMap::new();
// Vec of nodes that have zero out degree, least recent events.
let mut zero_outdegree = Vec::new();
@ -727,8 +737,8 @@ async fn get_mainline_depth<E, F, Fut>(
where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Send + Sync,
{
while let Some(sort_ev) = event {
debug!(event_id = sort_ev.event_id().borrow().as_str(), "mainline");
@ -758,10 +768,10 @@ async fn add_event_and_auth_chain_to_graph<E, F, Fut>(
auth_diff: &HashSet<E::Id>,
fetch_event: &F,
) where
F: Fn(E::Id) -> Fut,
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Clone + Send,
E: Event + Send + Sync,
E::Id: Borrow<EventId> + Clone + Send + Sync,
{
let mut state = vec![event_id];
while let Some(eid) = state.pop() {
@ -788,7 +798,7 @@ where
F: Fn(E::Id) -> Fut + Sync,
Fut: Future<Output = Option<E>> + Send,
E: Event + Send,
E::Id: Borrow<EventId> + Send,
E::Id: Borrow<EventId> + Send + Sync,
{
match fetch(event_id.clone()).await.as_ref() {
| Some(state) => is_power_event(state),
@ -820,18 +830,18 @@ fn is_power_event(event: impl Event) -> bool {
/// Convenience trait for adding event type plus state key to state maps.
pub trait EventTypeExt {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String);
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey);
}
impl EventTypeExt for StateEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self, state_key.into())
}
}
impl EventTypeExt for TimelineEventType {
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
(self.to_string().into(), state_key.into())
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
(self.into(), state_key.into())
}
}
@ -839,7 +849,7 @@ impl<T> EventTypeExt for &T
where
T: EventTypeExt + Clone,
{
fn with_state_key(self, state_key: impl Into<String>) -> (StateEventType, String) {
fn with_state_key(self, state_key: impl Into<StateKey>) -> (StateEventType, StateKey) {
self.to_owned().with_state_key(state_key)
}
}
@ -858,13 +868,11 @@ mod tests {
room::join_rules::{JoinRule, RoomJoinRulesEventContent},
StateEventType, TimelineEventType,
},
int, uint,
int, uint, MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId,
};
use ruma_common::{MilliSecondsSinceUnixEpoch, OwnedEventId, RoomVersionId};
use serde_json::{json, value::to_raw_value as to_raw_json_value};
use tracing::debug;
use crate::{
use super::{
is_power_event,
room_version::RoomVersion,
test_utils::{
@ -874,6 +882,7 @@ mod tests {
},
Event, EventTypeExt, StateMap,
};
use crate::debug;
async fn test_event_sort() {
use futures::future::ready;
@ -898,11 +907,11 @@ mod tests {
let fetcher = |id| ready(events.get(&id).cloned());
let sorted_power_events =
crate::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
super::reverse_topological_power_sort(power_events, &auth_chain, &fetcher, 1)
.await
.unwrap();
let resolved_power = crate::iterative_auth_check(
let resolved_power = super::iterative_auth_check(
&RoomVersion::V6,
sorted_power_events.iter(),
HashMap::new(), // unconflicted events
@ -918,10 +927,10 @@ mod tests {
events_to_sort.shuffle(&mut rand::thread_rng());
let power_level = resolved_power
.get(&(StateEventType::RoomPowerLevels, "".to_owned()))
.get(&(StateEventType::RoomPowerLevels, "".into()))
.cloned();
let sorted_event_ids = crate::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
let sorted_event_ids = super::mainline_sort(&events_to_sort, power_level, &fetcher, 1)
.await
.unwrap();
@ -1302,7 +1311,7 @@ mod tests {
})
.collect();
let resolved = match crate::resolve(
let resolved = match super::resolve(
&RoomVersionId::V2,
&state_sets,
&auth_chain,
@ -1333,7 +1342,7 @@ mod tests {
event_id("p") => hashset![event_id("o")],
};
let res = crate::lexicographical_topological_sort(&graph, &|_id| async {
let res = super::lexicographical_topological_sort(&graph, &|_id| async {
Ok((int!(0), MilliSecondsSinceUnixEpoch(uint!(0))))
})
.await
@ -1421,7 +1430,7 @@ mod tests {
let fetcher = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).cloned());
let exists = |id: <PduEvent as Event>::Id| ready(ev_map.get(&id).is_some());
let resolved = match crate::resolve(
let resolved = match super::resolve(
&RoomVersionId::V6,
&state_sets,
&auth_chain,
@ -1552,7 +1561,7 @@ mod tests {
#[allow(unused_mut)]
let mut x = StateMap::new();
$(
x.insert(($kind, $key.to_owned()), $id);
x.insert(($kind, $key.into()), $id);
)*
x
}};