simplify get_pdu() interface; eliminate unconditional Arc

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-29 08:26:27 +00:00
parent 58be22e695
commit 6175e72f1c
16 changed files with 51 additions and 52 deletions

View file

@ -266,15 +266,15 @@ pub(super) async fn get_remote_pdu(
#[admin_command] #[admin_command]
pub(super) async fn get_room_state(&self, room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> { pub(super) async fn get_room_state(&self, room: OwnedRoomOrAliasId) -> Result<RoomMessageEventContent> {
let room_id = self.services.rooms.alias.resolve(&room).await?; let room_id = self.services.rooms.alias.resolve(&room).await?;
let room_state = self let room_state: Vec<_> = self
.services .services
.rooms .rooms
.state_accessor .state_accessor
.room_state_full(&room_id) .room_state_full(&room_id)
.await? .await?
.values() .values()
.map(|pdu| pdu.to_state_event()) .map(PduEvent::to_state_event)
.collect::<Vec<_>>(); .collect();
if room_state.is_empty() { if room_state.is_empty() {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(

View file

@ -103,7 +103,7 @@ pub(crate) async fn get_context_route(
.collect() .collect()
.await; .await;
let lazy = once(&(base_token, (*base_event).clone())) let lazy = once(&(base_token, base_event.clone()))
.chain(events_before.iter()) .chain(events_before.iter())
.chain(events_after.iter()) .chain(events_after.iter())
.stream() .stream()

View file

@ -137,7 +137,7 @@ pub(crate) async fn report_event_route(
/// check if reporting user is in the reporting room /// check if reporting user is in the reporting room
async fn is_event_report_valid( async fn is_event_report_valid(
services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: Option<&String>, services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: Option<&String>,
score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>, score: Option<ruma::Int>, pdu: &PduEvent,
) -> Result<()> { ) -> Result<()> {
debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid"); debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid");

View file

@ -18,7 +18,7 @@ pub(crate) async fn get_room_event_route(
event: services event: services
.rooms .rooms
.timeline .timeline
.get_pdu_owned(&body.event_id) .get_pdu(&body.event_id)
.map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id)))) .map_err(|_| err!(Request(NotFound("Event {} not found.", &body.event_id))))
.and_then(|event| async move { .and_then(|event| async move {
services services

View file

@ -181,11 +181,7 @@ async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result<Roo
.room_state_full(room_id) .room_state_full(room_id)
.await?; .await?;
let state_events = state_map let state_events = state_map.values().map(PduEvent::to_state_event).collect();
.values()
.map(AsRef::as_ref)
.map(PduEvent::to_state_event)
.collect();
Ok(state_events) Ok(state_events)
} }

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State; use axum::extract::State;
use conduit::{err, pdu::PduBuilder, utils::BoolExt, Err, Error, Result}; use conduit::{err, pdu::PduBuilder, utils::BoolExt, Err, Error, PduEvent, Result};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -97,7 +97,7 @@ pub(crate) async fn get_state_events_route(
.room_state_full(&body.room_id) .room_state_full(&body.room_id)
.await? .await?
.values() .values()
.map(|pdu| pdu.to_state_event()) .map(PduEvent::to_state_event)
.collect(), .collect(),
}) })
} }

View file

@ -1021,7 +1021,7 @@ async fn load_joined_room(
state: RoomState { state: RoomState {
events: state_events events: state_events
.iter() .iter()
.map(|pdu| pdu.to_sync_state_event()) .map(PduEvent::to_sync_state_event)
.collect(), .collect(),
}, },
ephemeral: Ephemeral { ephemeral: Ephemeral {

View file

@ -7,6 +7,7 @@ use std::{
use conduit::{ use conduit::{
debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent, debug, debug_error, implement, info, pdu, trace, utils::math::continue_exponential_backoff_secs, warn, PduEvent,
}; };
use futures::TryFutureExt;
use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName}; use ruma::{api::federation::event::get_event, CanonicalJsonValue, EventId, RoomId, RoomVersionId, ServerName};
/// Find the event and auth it. Once the event is validated (steps 1 - 8) /// Find the event and auth it. Once the event is validated (steps 1 - 8)
@ -42,7 +43,7 @@ pub(super) async fn fetch_and_handle_outliers<'a>(
// a. Look in the main timeline (pduid_pdu tree) // a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree // b. Look at outlier pdu tree
// (get_pdu_json checks both) // (get_pdu_json checks both)
if let Ok(local_pdu) = self.services.timeline.get_pdu(id).await { if let Ok(local_pdu) = self.services.timeline.get_pdu(id).map_ok(Arc::new).await {
trace!("Found {id} in db"); trace!("Found {id} in db");
events_with_auth_events.push((id, Some(local_pdu), vec![])); events_with_auth_events.push((id, Some(local_pdu), vec![]));
continue; continue;

View file

@ -1,10 +1,11 @@
use std::{ use std::{
collections::{hash_map, BTreeMap}, collections::{hash_map, BTreeMap},
sync::Arc,
time::Instant, time::Instant,
}; };
use conduit::{debug, err, implement, warn, Error, Result}; use conduit::{debug, err, implement, warn, Error, Result};
use futures::FutureExt; use futures::{FutureExt, TryFutureExt};
use ruma::{ use ruma::{
api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId, api::client::error::ErrorKind, events::StateEventType, CanonicalJsonValue, EventId, RoomId, ServerName, UserId,
}; };
@ -79,6 +80,7 @@ pub async fn handle_incoming_pdu<'a>(
.services .services
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "") .room_state_get(room_id, &StateEventType::RoomCreate, "")
.map_ok(Arc::new)
.await?; .await?;
// Procure the room version // Procure the room version

View file

@ -4,7 +4,7 @@ use std::{
}; };
use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result}; use conduit::{debug, debug_info, err, implement, trace, warn, Err, Error, PduEvent, Result};
use futures::future::ready; use futures::{future::ready, TryFutureExt};
use ruma::{ use ruma::{
api::client::error::ErrorKind, api::client::error::ErrorKind,
events::StateEventType, events::StateEventType,
@ -94,7 +94,7 @@ pub(super) async fn handle_outlier_pdu<'a>(
// Build map of auth events // Build map of auth events
let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len()); let mut auth_events = HashMap::with_capacity(incoming_pdu.auth_events.len());
for id in &incoming_pdu.auth_events { for id in &incoming_pdu.auth_events {
let Ok(auth_event) = self.services.timeline.get_pdu(id).await else { let Ok(auth_event) = self.services.timeline.get_pdu(id).map_ok(Arc::new).await else {
warn!("Could not find auth event {id}"); warn!("Could not find auth event {id}");
continue; continue;
}; };

View file

@ -17,7 +17,11 @@ use std::{
time::Instant, time::Instant,
}; };
use conduit::{utils::MutexMap, Err, PduEvent, Result, Server}; use conduit::{
utils::{MutexMap, TryFutureExtExt},
Err, PduEvent, Result, Server,
};
use futures::TryFutureExt;
use ruma::{ use ruma::{
events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId, events::room::create::RoomCreateEventContent, state_res::RoomVersion, EventId, OwnedEventId, OwnedRoomId, RoomId,
RoomVersionId, RoomVersionId,
@ -94,7 +98,12 @@ impl Service {
async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await } async fn event_exists(&self, event_id: Arc<EventId>) -> bool { self.services.timeline.pdu_exists(&event_id).await }
async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> { async fn event_fetch(&self, event_id: Arc<EventId>) -> Option<Arc<PduEvent>> {
self.services.timeline.get_pdu(&event_id).await.ok() self.services
.timeline
.get_pdu(&event_id)
.map_ok(Arc::new)
.ok()
.await
} }
} }

View file

@ -445,7 +445,7 @@ impl Service {
.into_iter() .into_iter()
.map(at!(0)) .map(at!(0))
.zip(auth_pdus.into_iter()) .zip(auth_pdus.into_iter())
.filter_map(|((event_type, state_key), pdu)| Some(((event_type, state_key), pdu.ok()?))) .filter_map(|((event_type, state_key), pdu)| Some(((event_type, state_key), pdu.ok()?.into())))
.collect(); .collect();
Ok(auth_pdus) Ok(auth_pdus)

View file

@ -46,7 +46,7 @@ impl Data {
pub(super) async fn state_full( pub(super) async fn state_full(
&self, shortstatehash: ShortStateHash, &self, shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), PduEvent>> {
let state = self let state = self
.state_full_pdus(shortstatehash) .state_full_pdus(shortstatehash)
.await? .await?
@ -57,24 +57,27 @@ impl Data {
Ok(state) Ok(state)
} }
pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<Arc<PduEvent>>> { pub(super) async fn state_full_pdus(&self, shortstatehash: ShortStateHash) -> Result<Vec<PduEvent>> {
let short_ids = self let short_ids = self
.state_full_shortids(shortstatehash) .state_full_shortids(shortstatehash)
.await? .await?
.into_iter() .into_iter()
.map(at!(1)); .map(at!(1));
let event_ids = self let event_ids: Vec<OwnedEventId> = self
.services .services
.short .short
.multi_get_eventid_from_short(short_ids) .multi_get_eventid_from_short(short_ids)
.await; .await
.into_iter()
.filter_map(Result::ok)
.collect();
let full_pdus = event_ids let full_pdus = event_ids
.into_iter() .iter()
.stream() .stream()
.then(|event_id| self.services.timeline.get_pdu(event_id))
.ready_filter_map(Result::ok) .ready_filter_map(Result::ok)
.filter_map(|event_id: OwnedEventId| async move { self.services.timeline.get_pdu(&event_id).await.ok() })
.collect() .collect()
.await; .await;
@ -157,7 +160,7 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get( pub(super) async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<PduEvent>> { ) -> Result<PduEvent> {
self.state_get_id(shortstatehash, event_type, state_key) self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await }) .and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await })
.await .await
@ -181,7 +184,7 @@ impl Data {
/// Returns the full room state. /// Returns the full room state.
pub(super) async fn room_state_full( pub(super) async fn room_state_full(
&self, room_id: &RoomId, &self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.services self.services
.state .state
.get_room_shortstatehash(room_id) .get_room_shortstatehash(room_id)
@ -192,7 +195,7 @@ impl Data {
/// Returns the full room state's pdus. /// Returns the full room state's pdus.
#[allow(unused_qualifications)] // async traits #[allow(unused_qualifications)] // async traits
pub(super) async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<Arc<PduEvent>>> { pub(super) async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
self.services self.services
.state .state
.get_room_shortstatehash(room_id) .get_room_shortstatehash(room_id)
@ -215,7 +218,7 @@ impl Data {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get( pub(super) async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<PduEvent>> { ) -> Result<PduEvent> {
self.services self.services
.state .state
.get_room_shortstatehash(room_id) .get_room_shortstatehash(room_id)

View file

@ -114,7 +114,7 @@ impl Service {
pub async fn state_full( pub async fn state_full(
&self, shortstatehash: ShortStateHash, &self, shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { ) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.state_full(shortstatehash).await self.db.state_full(shortstatehash).await
} }
@ -134,7 +134,7 @@ impl Service {
#[inline] #[inline]
pub async fn state_get( pub async fn state_get(
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<PduEvent>> { ) -> Result<PduEvent> {
self.db self.db
.state_get(shortstatehash, event_type, state_key) .state_get(shortstatehash, event_type, state_key)
.await .await
@ -311,13 +311,13 @@ impl Service {
/// Returns the full room state. /// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> { pub async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), PduEvent>> {
self.db.room_state_full(room_id).await self.db.room_state_full(room_id).await
} }
/// Returns the full room state pdus /// Returns the full room state pdus
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<Arc<PduEvent>>> { pub async fn room_state_full_pdus(&self, room_id: &RoomId) -> Result<Vec<PduEvent>> {
self.db.room_state_full_pdus(room_id).await self.db.room_state_full_pdus(room_id).await
} }
@ -337,7 +337,7 @@ impl Service {
#[tracing::instrument(skip(self), level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub async fn room_state_get( pub async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<PduEvent>> { ) -> Result<PduEvent> {
self.db.room_state_get(room_id, event_type, state_key).await self.db.room_state_get(room_id, event_type, state_key).await
} }

View file

@ -126,14 +126,7 @@ impl Data {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> { pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
self.get_pdu_owned(event_id).await.map(Arc::new)
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub(super) async fn get_pdu_owned(&self, event_id: &EventId) -> Result<PduEvent> {
let accepted = self.get_non_outlier_pdu(event_id).boxed(); let accepted = self.get_non_outlier_pdu(event_id).boxed();
let outlier = self let outlier = self
.eventid_outlierpdu .eventid_outlierpdu

View file

@ -242,12 +242,7 @@ impl Service {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> { self.db.get_pdu(event_id).await } pub async fn get_pdu(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu(event_id).await }
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub async fn get_pdu_owned(&self, event_id: &EventId) -> Result<PduEvent> { self.db.get_pdu_owned(event_id).await }
/// Checks if pdu exists /// Checks if pdu exists
/// ///
@ -327,11 +322,11 @@ impl Service {
); );
unsigned.insert( unsigned.insert(
String::from("prev_sender"), String::from("prev_sender"),
CanonicalJsonValue::String(prev_state.sender.clone().to_string()), CanonicalJsonValue::String(prev_state.sender.to_string()),
); );
unsigned.insert( unsigned.insert(
String::from("replaces_state"), String::from("replaces_state"),
CanonicalJsonValue::String(prev_state.event_id.clone().to_string()), CanonicalJsonValue::String(prev_state.event_id.to_string()),
); );
} }
} }