refactor incoming extremities retention; broad filter, single pass

Signed-off-by: Jason Volk <jason@zemos.net>
Jason Volk 2025-01-30 04:39:24 +00:00
parent 31c2968bb2
commit 1a8482b3b4
4 changed files with 74 additions and 68 deletions

View file

@@ -1,6 +1,7 @@
 use std::{
 	borrow::Borrow,
 	collections::{BTreeMap, HashMap, HashSet},
+	iter::once,
 	net::IpAddr,
 	sync::Arc,
 };
@@ -1216,7 +1217,7 @@ async fn join_room_by_id_helper_remote(
 		.append_pdu(
 			&parsed_join_pdu,
 			join_event,
-			vec![(*parsed_join_pdu.event_id).to_owned()],
+			once(parsed_join_pdu.event_id.borrow()),
 			&state_lock,
 		)
 		.await?;
@@ -2195,7 +2196,7 @@ async fn knock_room_helper_local(
 		.append_pdu(
 			&parsed_knock_pdu,
 			knock_event,
-			vec![(*parsed_knock_pdu.event_id).to_owned()],
+			once(parsed_knock_pdu.event_id.borrow()),
 			&state_lock,
 		)
 		.await?;
@@ -2394,7 +2395,7 @@ async fn knock_room_helper_remote(
 		.append_pdu(
 			&parsed_knock_pdu,
 			knock_event,
-			vec![(*parsed_knock_pdu.event_id).to_owned()],
+			once(parsed_knock_pdu.event_id.borrow()),
 			&state_lock,
 		)
 		.await?;

View file

@@ -1,14 +1,18 @@
 use std::{
 	borrow::Borrow,
 	collections::{BTreeMap, HashSet},
+	iter::once,
 	sync::Arc,
 	time::Instant,
 };
 
-use conduwuit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result};
-use futures::{future::ready, StreamExt};
+use conduwuit::{
+	debug, debug_info, err, implement, trace,
+	utils::stream::{BroadbandExt, ReadyExt},
+	warn, Err, PduEvent, Result,
+};
+use futures::{future::ready, FutureExt, StreamExt};
 use ruma::{
-	api::client::error::ErrorKind,
 	events::{room::redaction::RoomRedactionEventContent, StateEventType, TimelineEventType},
 	state_res::{self, EventTypeExt},
 	CanonicalJsonValue, RoomId, RoomVersionId, ServerName,
@@ -174,42 +178,34 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
 	// Now we calculate the set of extremities this room has after the incoming
 	// event has been applied. We start with the previous extremities (aka leaves)
 	trace!("Calculating extremities");
-	let mut extremities: HashSet<_> = self
+	let extremities: Vec<_> = self
 		.services
 		.state
 		.get_forward_extremities(room_id)
 		.map(ToOwned::to_owned)
+		.ready_filter(|event_id| {
+			// Remove any that are referenced by this incoming event's prev_events
+			!incoming_pdu.prev_events.contains(event_id)
+		})
+		.broad_filter_map(|event_id| async move {
+			// Only keep those extremities were not referenced yet
+			self.services
+				.pdu_metadata
+				.is_event_referenced(room_id, &event_id)
+				.await
+				.eq(&false)
+				.then_some(event_id)
+		})
 		.collect()
 		.await;
 
-	// Remove any forward extremities that are referenced by this incoming event's
-	// prev_events
-	trace!(
-		"Calculated {} extremities; checking against {} prev_events",
+	debug!(
+		"Retained {} extremities checked against {} prev_events",
 		extremities.len(),
 		incoming_pdu.prev_events.len()
 	);
-	for prev_event in &incoming_pdu.prev_events {
-		extremities.remove(&(**prev_event));
-	}
-
-	// Only keep those extremities were not referenced yet
-	let mut retained = HashSet::new();
-	for id in &extremities {
-		if !self
-			.services
-			.pdu_metadata
-			.is_event_referenced(room_id, id)
-			.await
-		{
-			retained.insert(id.clone());
-		}
-	}
-
-	extremities.retain(|id| retained.contains(id));
-	debug!("Retained {} extremities. Compressing state", extremities.len());
 
-	let state_ids_compressed: HashSet<_> = self
+	let state_ids_compressed: Arc<HashSet<_>> = self
 		.services
 		.state_compressor
 		.compress_state_events(
@@ -218,10 +214,9 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
 				.map(|(ssk, eid)| (ssk, eid.borrow())),
 		)
 		.collect()
+		.map(Arc::new)
 		.await;
 
-	let state_ids_compressed = Arc::new(state_ids_compressed);
-
 	if incoming_pdu.state_key.is_some() {
 		debug!("Event is a state-event. Deriving new room state");
 
@@ -260,12 +255,14 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
 	// if not soft fail it
 	if soft_fail {
 		debug!("Soft failing event");
+		let extremities = extremities.iter().map(Borrow::borrow);
+
 		self.services
 			.timeline
 			.append_incoming_pdu(
 				&incoming_pdu,
 				val,
-				extremities.iter().map(|e| (**e).to_owned()).collect(),
+				extremities,
 				state_ids_compressed,
 				soft_fail,
 				&state_lock,
@@ -273,27 +270,30 @@ pub(super) async fn upgrade_outlier_to_timeline_pdu(
 			.await?;
 
 		// Soft fail, we keep the event as an outlier but don't add it to the timeline
-		warn!("Event was soft failed: {incoming_pdu:?}");
 		self.services
 			.pdu_metadata
 			.mark_event_soft_failed(&incoming_pdu.event_id);
 
-		return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed"));
+		warn!("Event was soft failed: {incoming_pdu:?}");
+		return Err!(Request(InvalidParam("Event has been soft failed")));
 	}
 
-	trace!("Appending pdu to timeline");
-	extremities.insert(incoming_pdu.event_id.clone());
-
 	// Now that the event has passed all auth it is added into the timeline.
 	// We use the `state_at_event` instead of `state_after` so we accurately
 	// represent the state for this event.
+	trace!("Appending pdu to timeline");
+	let extremities = extremities
+		.iter()
+		.map(Borrow::borrow)
+		.chain(once(incoming_pdu.event_id.borrow()));
+
 	let pdu_id = self
 		.services
 		.timeline
 		.append_incoming_pdu(
 			&incoming_pdu,
 			val,
-			extremities.into_iter().collect(),
+			extremities,
 			state_ids_compressed,
 			soft_fail,
 			&state_lock,
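
On the success path the incoming event itself becomes a new forward extremity, so rather than inserting into a mutable set the code borrows each retained extremity and chains the new event id on the end. In isolation, with plain Strings standing in for the event-id types, the pattern looks like:

use std::{borrow::Borrow, iter::once};

fn main() {
    let extremities: Vec<String> = vec!["$leaf_a".into(), "$leaf_b".into()];
    let incoming: String = "$incoming".into();

    // Borrow every retained extremity and append the incoming event id,
    // yielding one lazy Iterator<Item = &str> with no intermediate set.
    let leaves: Vec<&str> = extremities
        .iter()
        .map(Borrow::borrow)
        .chain(once(incoming.borrow()))
        .collect();

    assert_eq!(leaves, ["$leaf_a", "$leaf_b", "$incoming"]);
}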

View file

@@ -398,13 +398,14 @@ impl Service {
 			.ignore_err()
 	}
 
-	pub async fn set_forward_extremities(
-		&self,
-		room_id: &RoomId,
-		event_ids: Vec<OwnedEventId>,
-		_state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room
-		                               * state mutex */
-	) {
+	pub async fn set_forward_extremities<'a, I>(
+		&'a self,
+		room_id: &'a RoomId,
+		event_ids: I,
+		_state_lock: &'a RoomMutexGuard,
+	) where
+		I: Iterator<Item = &'a EventId> + Send + 'a,
+	{
 		let prefix = (room_id, Interfix);
 		self.db
 			.roomid_pduleaves
@@ -413,7 +414,7 @@ impl Service {
 			.ready_for_each(|key| self.db.roomid_pduleaves.remove(key))
 			.await;
 
-		for event_id in &event_ids {
+		for event_id in event_ids {
 			let key = (room_id, event_id);
 			self.db.roomid_pduleaves.put_raw(key, event_id);
 		}
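
set_forward_extremities now accepts any Send iterator of borrowed event IDs instead of a Vec<OwnedEventId>, so a single once(id) or a mapped slice can feed it without cloning. A simplified model of that signature change, using a toy Leaves store and &str keys in place of the real database handle and EventId:

use std::iter::once;

// Toy stand-in for the roomid_pduleaves table.
#[derive(Default)]
struct Leaves(Vec<(String, String)>);

impl Leaves {
    // Before: fn set_forward_extremities(&mut self, room_id: &str, event_ids: Vec<String>)
    // After: borrow the ids through any iterator; owned storage only at the write.
    fn set_forward_extremities<'a, I>(&mut self, room_id: &'a str, event_ids: I)
    where
        I: Iterator<Item = &'a str> + Send + 'a,
    {
        // Clear the old leaves for this room, then write the new ones.
        self.0.retain(|(room, _)| room != room_id);
        for event_id in event_ids {
            self.0.push((room_id.to_owned(), event_id.to_owned()));
        }
    }
}

fn main() {
    let mut leaves = Leaves::default();

    // Single new leaf: no Vec, just `once`.
    leaves.set_forward_extremities("!room:example.org", once("$only_leaf"));

    // Several retained leaves straight from a borrowed slice.
    let retained = ["$a", "$b"];
    leaves.set_forward_extremities("!room:example.org", retained.iter().copied());

    assert_eq!(leaves.0.len(), 2);
}

The Send + 'a bounds mirror the real signature; they matter there because the iterator is held across .await points inside an async fn. The same generic-leafs shape is what append_pdu and append_incoming_pdu adopt in the timeline service below.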

View file

@@ -1,6 +1,7 @@
 mod data;
 
 use std::{
+	borrow::Borrow,
 	cmp,
 	collections::{BTreeMap, HashSet},
 	fmt::Write,
@@ -260,14 +261,16 @@ impl Service {
 	///
 	/// Returns pdu id
 	#[tracing::instrument(level = "debug", skip_all)]
-	pub async fn append_pdu(
-		&self,
-		pdu: &PduEvent,
+	pub async fn append_pdu<'a, Leafs>(
+		&'a self,
+		pdu: &'a PduEvent,
 		mut pdu_json: CanonicalJsonObject,
-		leaves: Vec<OwnedEventId>,
-		state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
-		                              * mutex */
-	) -> Result<RawPduId> {
+		leafs: Leafs,
+		state_lock: &'a RoomMutexGuard,
+	) -> Result<RawPduId>
+	where
+		Leafs: Iterator<Item = &'a EventId> + Send + 'a,
+	{
 		// Coalesce database writes for the remainder of this scope.
 		let _cork = self.db.db.cork_and_flush();
 
@@ -335,7 +338,7 @@ impl Service {
 
 		self.services
 			.state
-			.set_forward_extremities(&pdu.room_id, leaves, state_lock)
+			.set_forward_extremities(&pdu.room_id, leafs, state_lock)
 			.await;
 
 		let insert_lock = self.mutex_insert.lock(&pdu.room_id).await;
@@ -819,8 +822,7 @@ impl Service {
 		pdu_builder: PduBuilder,
 		sender: &UserId,
 		room_id: &RoomId,
-		state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
-		                              * mutex */
+		state_lock: &RoomMutexGuard,
 	) -> Result<OwnedEventId> {
 		let (pdu, pdu_json) = self
 			.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)
@@ -896,7 +898,7 @@ impl Service {
 				pdu_json,
 				// Since this PDU references all pdu_leaves we can update the leaves
 				// of the room
-				vec![(*pdu.event_id).to_owned()],
+				once(pdu.event_id.borrow()),
 				state_lock,
 			)
 			.boxed()
@@ -943,16 +945,18 @@ impl Service {
 	/// Append the incoming event setting the state snapshot to the state from
 	/// the server that sent the event.
 	#[tracing::instrument(level = "debug", skip_all)]
-	pub async fn append_incoming_pdu(
-		&self,
-		pdu: &PduEvent,
+	pub async fn append_incoming_pdu<'a, Leafs>(
+		&'a self,
+		pdu: &'a PduEvent,
 		pdu_json: CanonicalJsonObject,
-		new_room_leaves: Vec<OwnedEventId>,
+		new_room_leafs: Leafs,
 		state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
 		soft_fail: bool,
-		state_lock: &RoomMutexGuard, /* Take mutex guard to make sure users get the room state
-		                              * mutex */
-	) -> Result<Option<RawPduId>> {
+		state_lock: &'a RoomMutexGuard,
+	) -> Result<Option<RawPduId>>
+	where
+		Leafs: Iterator<Item = &'a EventId> + Send + 'a,
+	{
 		// We append to state before appending the pdu, so we don't have a moment in
 		// time with the pdu without it's state. This is okay because append_pdu can't
 		// fail.
@@ -968,14 +972,14 @@ impl Service {
 
 			self.services
 				.state
-				.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)
+				.set_forward_extremities(&pdu.room_id, new_room_leafs, state_lock)
 				.await;
 
 			return Ok(None);
 		}
 
 		let pdu_id = self
-			.append_pdu(pdu, pdu_json, new_room_leaves, state_lock)
+			.append_pdu(pdu, pdu_json, new_room_leafs, state_lock)
 			.await?;
 
 		Ok(Some(pdu_id))