diff --git a/Cargo.lock b/Cargo.lock index 5848cc46..3a435a10 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -685,6 +685,7 @@ dependencies = [ "http-body-util", "hyper", "ipaddress", + "itertools 0.13.0", "log", "rand", "reqwest", diff --git a/src/api/Cargo.toml b/src/api/Cargo.toml index 1b463fbc..385e786f 100644 --- a/src/api/Cargo.toml +++ b/src/api/Cargo.toml @@ -50,6 +50,7 @@ http.workspace = true http-body-util.workspace = true hyper.workspace = true ipaddress.workspace = true +itertools.workspace = true log.workspace = true rand.workspace = true reqwest.workspace = true diff --git a/src/api/server/send.rs b/src/api/server/send.rs index eec9bd11..016f5194 100644 --- a/src/api/server/send.rs +++ b/src/api/server/send.rs @@ -3,10 +3,17 @@ use std::{collections::BTreeMap, net::IpAddr, time::Instant}; use axum::extract::State; use axum_client_ip::InsecureClientIp; use conduwuit::{ - debug, debug_warn, err, error, result::LogErr, trace, utils::ReadyExt, warn, Err, Error, - Result, + debug, debug_warn, err, error, + result::LogErr, + trace, + utils::{ + stream::{automatic_width, BroadbandExt, TryBroadbandExt}, + IterStream, ReadyExt, + }, + warn, Err, Error, Result, }; -use futures::{FutureExt, StreamExt}; +use futures::{FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt}; +use itertools::Itertools; use ruma::{ api::{ client::error::ErrorKind, @@ -19,11 +26,9 @@ use ruma::{ }, }, events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, - serde::Raw, to_device::DeviceIdOrAllDevices, - OwnedEventId, ServerName, + CanonicalJsonObject, OwnedEventId, OwnedRoomId, ServerName, }; -use serde_json::value::RawValue as RawJsonValue; use service::{ sending::{EDU_LIMIT, PDU_LIMIT}, Services, @@ -34,7 +39,8 @@ use crate::{ Ruma, }; -type ResolvedMap = BTreeMap>; +type ResolvedMap = BTreeMap; +type Pdu = (OwnedRoomId, OwnedEventId, CanonicalJsonObject); /// # `PUT /_matrix/federation/v1/send/{txnId}` /// @@ -73,91 +79,41 @@ pub(crate) async fn send_transaction_message_route( let txn_start_time = Instant::now(); trace!( - pdus = ?body.pdus.len(), - edus = ?body.edus.len(), + pdus = body.pdus.len(), + edus = body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, origin =?body.origin(), "Starting txn", ); - let resolved_map = - handle_pdus(&services, &client, &body.pdus, body.origin(), &txn_start_time) - .boxed() - .await?; + let pdus = body + .pdus + .iter() + .stream() + .broad_then(|pdu| services.rooms.event_handler.parse_incoming_pdu(pdu)) + .inspect_err(|e| debug_warn!("Could not parse PDU: {e}")) + .ready_filter_map(Result::ok); - handle_edus(&services, &client, &body.edus, body.origin()) - .boxed() - .await; + let edus = body + .edus + .iter() + .map(|edu| edu.json().get()) + .map(serde_json::from_str) + .filter_map(Result::ok) + .stream(); + + let results = handle(&services, &client, body.origin(), txn_start_time, pdus, edus).await?; debug!( - pdus = ?body.pdus.len(), - edus = ?body.edus.len(), + pdus = body.pdus.len(), + edus = body.edus.len(), elapsed = ?txn_start_time.elapsed(), id = ?body.transaction_id, origin =?body.origin(), "Finished txn", ); - - Ok(send_transaction_message::v1::Response { - pdus: resolved_map - .into_iter() - .map(|(e, r)| (e, r.map_err(error::sanitized_message))) - .collect(), - }) -} - -async fn handle_pdus( - services: &Services, - _client: &IpAddr, - pdus: &[Box], - origin: &ServerName, - txn_start_time: &Instant, -) -> Result { - let mut parsed_pdus = Vec::with_capacity(pdus.len()); - for pdu in pdus { - parsed_pdus.push(match services.rooms.event_handler.parse_incoming_pdu(pdu).await { - | Ok(t) => t, - | Err(e) => { - debug_warn!("Could not parse PDU: {e}"); - continue; - }, - }); - - // We do not add the event_id field to the pdu here because of signature - // and hashes checks - } - - let mut resolved_map = BTreeMap::new(); - for (event_id, value, room_id) in parsed_pdus { - services.server.check_running()?; - let pdu_start_time = Instant::now(); - let mutex_lock = services - .rooms - .event_handler - .mutex_federation - .lock(&room_id) - .await; - - let result = services - .rooms - .event_handler - .handle_incoming_pdu(origin, &room_id, &event_id, value, true) - .boxed() - .await - .map(|_| ()); - - drop(mutex_lock); - debug!( - pdu_elapsed = ?pdu_start_time.elapsed(), - txn_elapsed = ?txn_start_time.elapsed(), - "Finished PDU {event_id}", - ); - - resolved_map.insert(event_id, result); - } - - for (id, result) in &resolved_map { + for (id, result) in &results { if let Err(e) = result { if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { warn!("Incoming PDU failed {id}: {e:?}"); @@ -165,39 +121,112 @@ async fn handle_pdus( } } - Ok(resolved_map) + Ok(send_transaction_message::v1::Response { + pdus: results + .into_iter() + .map(|(e, r)| (e, r.map_err(error::sanitized_message))) + .collect(), + }) } -async fn handle_edus( +async fn handle( services: &Services, client: &IpAddr, - edus: &[Raw], origin: &ServerName, -) { - for edu in edus - .iter() - .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) - { - match edu { - | Edu::Presence(presence) => { - handle_edu_presence(services, client, origin, presence).await; - }, - | Edu::Receipt(receipt) => - handle_edu_receipt(services, client, origin, receipt).await, - | Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, - | Edu::DeviceListUpdate(content) => { - handle_edu_device_list_update(services, client, origin, content).await; - }, - | Edu::DirectToDevice(content) => { - handle_edu_direct_to_device(services, client, origin, content).await; - }, - | Edu::SigningKeyUpdate(content) => { - handle_edu_signing_key_update(services, client, origin, content).await; - }, - | Edu::_Custom(ref _custom) => { - debug_warn!(?edus, "received custom/unknown EDU"); - }, - } + started: Instant, + pdus: impl Stream + Send, + edus: impl Stream + Send, +) -> Result { + // group pdus by room + let pdus = pdus + .collect() + .map(|mut pdus: Vec<_>| { + pdus.sort_by(|(room_a, ..), (room_b, ..)| room_a.cmp(room_b)); + pdus.into_iter() + .into_grouping_map_by(|(room_id, ..)| room_id.clone()) + .collect() + }) + .await; + + // we can evaluate rooms concurrently + let results: ResolvedMap = pdus + .into_iter() + .try_stream() + .broad_and_then(|(room_id, pdus)| { + handle_room(services, client, origin, started, room_id, pdus) + .map_ok(Vec::into_iter) + .map_ok(IterStream::try_stream) + }) + .try_flatten() + .try_collect() + .boxed() + .await?; + + // evaluate edus after pdus, at least for now. + edus.for_each_concurrent(automatic_width(), |edu| handle_edu(services, client, origin, edu)) + .boxed() + .await; + + Ok(results) +} + +async fn handle_room( + services: &Services, + _client: &IpAddr, + origin: &ServerName, + txn_start_time: Instant, + room_id: OwnedRoomId, + pdus: Vec, +) -> Result> { + let _room_lock = services + .rooms + .event_handler + .mutex_federation + .lock(&room_id) + .await; + + let mut results = Vec::with_capacity(pdus.len()); + for (_, event_id, value) in pdus { + services.server.check_running()?; + let pdu_start_time = Instant::now(); + let result = services + .rooms + .event_handler + .handle_incoming_pdu(origin, &room_id, &event_id, value, true) + .await + .map(|_| ()); + + debug!( + pdu_elapsed = ?pdu_start_time.elapsed(), + txn_elapsed = ?txn_start_time.elapsed(), + "Finished PDU {event_id}", + ); + + results.push((event_id, result)); + } + + Ok(results) +} + +async fn handle_edu(services: &Services, client: &IpAddr, origin: &ServerName, edu: Edu) { + match edu { + | Edu::Presence(presence) => { + handle_edu_presence(services, client, origin, presence).await; + }, + | Edu::Receipt(receipt) => handle_edu_receipt(services, client, origin, receipt).await, + | Edu::Typing(typing) => handle_edu_typing(services, client, origin, typing).await, + | Edu::DeviceListUpdate(content) => { + handle_edu_device_list_update(services, client, origin, content).await; + }, + | Edu::DirectToDevice(content) => { + handle_edu_direct_to_device(services, client, origin, content).await; + }, + | Edu::SigningKeyUpdate(content) => { + handle_edu_signing_key_update(services, client, origin, content).await; + }, + | Edu::_Custom(ref _custom) => { + debug_warn!(?edu, "received custom/unknown EDU"); + }, } } diff --git a/src/service/rooms/event_handler/parse_incoming_pdu.rs b/src/service/rooms/event_handler/parse_incoming_pdu.rs index 0c11314d..9b130763 100644 --- a/src/service/rooms/event_handler/parse_incoming_pdu.rs +++ b/src/service/rooms/event_handler/parse_incoming_pdu.rs @@ -2,11 +2,10 @@ use conduwuit::{err, implement, pdu::gen_event_id_canonical_json, result::FlatOk use ruma::{CanonicalJsonObject, CanonicalJsonValue, OwnedEventId, OwnedRoomId}; use serde_json::value::RawValue as RawJsonValue; +type Parsed = (OwnedRoomId, OwnedEventId, CanonicalJsonObject); + #[implement(super::Service)] -pub async fn parse_incoming_pdu( - &self, - pdu: &RawJsonValue, -) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { +pub async fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result { let value = serde_json::from_str::(pdu.get()).map_err(|e| { err!(BadServerResponse(debug_warn!("Error parsing incoming event {e:?}"))) })?; @@ -28,5 +27,5 @@ pub async fn parse_incoming_pdu( err!(Request(InvalidParam("Could not convert event to canonical json: {e}"))) })?; - Ok((event_id, value, room_id)) + Ok((room_id, event_id, value)) } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 362bfab5..bf585a6b 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1166,7 +1166,7 @@ impl Service { #[tracing::instrument(skip(self, pdu), level = "debug")] pub async fn backfill_pdu(&self, origin: &ServerName, pdu: Box) -> Result<()> { - let (event_id, value, room_id) = + let (room_id, event_id, value) = self.services.event_handler.parse_incoming_pdu(&pdu).await?; // Lock so we cannot backfill the same pdu twice at the same time