diff --git a/Cargo.lock b/Cargo.lock index 3a435a10..e379aebb 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -845,6 +845,7 @@ dependencies = [ "serde_json", "serde_yaml", "sha2", + "smallvec", "termimad", "tokio", "tracing", diff --git a/src/api/client/to_device.rs b/src/api/client/to_device.rs index 2ded04e7..1b942fba 100644 --- a/src/api/client/to_device.rs +++ b/src/api/client/to_device.rs @@ -10,6 +10,7 @@ use ruma::{ }, to_device::DeviceIdOrAllDevices, }; +use service::sending::EduBuf; use crate::Ruma; @@ -42,18 +43,21 @@ pub(crate) async fn send_event_to_device_route( messages.insert(target_user_id.clone(), map); let count = services.globals.next_count()?; - services.sending.send_edu_server( - target_user_id.server_name(), - serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( - DirectDeviceContent { - sender: sender_user.clone(), - ev_type: body.event_type.clone(), - message_id: count.to_string().into(), - messages, - }, - )) - .expect("DirectToDevice EDU can be serialized"), - )?; + let mut buf = EduBuf::new(); + serde_json::to_writer( + &mut buf, + &federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { + sender: sender_user.clone(), + ev_type: body.event_type.clone(), + message_id: count.to_string().into(), + messages, + }), + ) + .expect("DirectToDevice EDU can be serialized"); + + services + .sending + .send_edu_server(target_user_id.server_name(), buf)?; continue; } diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index 21fbb417..c4f75453 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -74,6 +74,7 @@ serde_json.workspace = true serde.workspace = true serde_yaml.workspace = true sha2.workspace = true +smallvec.workspace = true termimad.workspace = true termimad.optional = true tokio.workspace = true diff --git a/src/service/rooms/typing/mod.rs b/src/service/rooms/typing/mod.rs index a6123322..c710b33a 100644 --- a/src/service/rooms/typing/mod.rs +++ b/src/service/rooms/typing/mod.rs @@ -13,7 +13,7 @@ use ruma::{ }; use tokio::sync::{broadcast, RwLock}; -use crate::{globals, sending, users, Dep}; +use crate::{globals, sending, sending::EduBuf, users, Dep}; pub struct Service { server: Arc, @@ -228,12 +228,13 @@ impl Service { return Ok(()); } - let edu = Edu::Typing(TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing)); + let content = TypingContent::new(room_id.to_owned(), user_id.to_owned(), typing); + let edu = Edu::Typing(content); - self.services - .sending - .send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing")) - .await?; + let mut buf = EduBuf::new(); + serde_json::to_writer(&mut buf, &edu).expect("Serialized Edu::Typing"); + + self.services.sending.send_edu_room(room_id, buf).await?; Ok(()) } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 436f633e..4dd2d5aa 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -202,7 +202,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se if value.is_empty() { SendingEvent::Pdu(event.into()) } else { - SendingEvent::Edu(value.to_vec()) + SendingEvent::Edu(value.into()) }, ) } else if key.starts_with(b"$") { @@ -230,7 +230,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se SendingEvent::Pdu(event.into()) } else { // I'm pretty sure this should never be called - SendingEvent::Edu(value.to_vec()) + SendingEvent::Edu(value.into()) }, ) } else { @@ -252,7 +252,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se if value.is_empty() { SendingEvent::Pdu(event.into()) } else { - SendingEvent::Edu(value.to_vec()) + SendingEvent::Edu(value.into()) }, ) }) diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 80bca112..b146ad49 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -21,6 +21,7 @@ use ruma::{ api::{appservice::Registration, OutgoingRequest}, RoomId, ServerName, UserId, }; +use smallvec::SmallVec; use tokio::task::JoinSet; use self::data::Data; @@ -67,10 +68,16 @@ struct Msg { #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum SendingEvent { Pdu(RawPduId), // pduid - Edu(Vec), // pdu json + Edu(EduBuf), // edu json Flush, // none } +pub type EduBuf = SmallVec<[u8; EDU_BUF_CAP]>; +pub type EduVec = SmallVec<[EduBuf; EDU_VEC_CAP]>; + +const EDU_BUF_CAP: usize = 128; +const EDU_VEC_CAP: usize = 1; + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -177,7 +184,6 @@ impl Service { where S: Stream + Send + 'a, { - let _cork = self.db.db.cork(); let requests = servers .map(|server| { (Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned())) @@ -185,6 +191,7 @@ impl Service { .collect::>() .await; + let _cork = self.db.db.cork(); let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o))); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { @@ -195,7 +202,7 @@ impl Service { } #[tracing::instrument(skip(self, server, serialized), level = "debug")] - pub fn send_edu_server(&self, server: &ServerName, serialized: Vec) -> Result<()> { + pub fn send_edu_server(&self, server: &ServerName, serialized: EduBuf) -> Result { let dest = Destination::Federation(server.to_owned()); let event = SendingEvent::Edu(serialized); let _cork = self.db.db.cork(); @@ -208,7 +215,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id, serialized), level = "debug")] - pub async fn send_edu_room(&self, room_id: &RoomId, serialized: Vec) -> Result<()> { + pub async fn send_edu_room(&self, room_id: &RoomId, serialized: EduBuf) -> Result { let servers = self .services .state_cache @@ -219,11 +226,10 @@ impl Service { } #[tracing::instrument(skip(self, servers, serialized), level = "debug")] - pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: Vec) -> Result<()> + pub async fn send_edu_servers<'a, S>(&self, servers: S, serialized: EduBuf) -> Result where S: Stream + Send + 'a, { - let _cork = self.db.db.cork(); let requests = servers .map(|server| { ( @@ -234,6 +240,7 @@ impl Service { .collect::>() .await; + let _cork = self.db.db.cork(); let keys = self.db.queue_requests(requests.iter().map(|(o, e)| (e, o))); for ((dest, event), queue_id) in requests.into_iter().zip(keys) { diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 47be01f1..363bb994 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -45,7 +45,9 @@ use ruma::{ }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use super::{appservice, data::QueueItem, Destination, Msg, SendingEvent, Service}; +use super::{ + appservice, data::QueueItem, Destination, EduBuf, EduVec, Msg, SendingEvent, Service, +}; #[derive(Debug)] enum TransactionStatus { @@ -313,7 +315,12 @@ impl Service { if let Destination::Federation(server_name) = dest { if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit"); - events.extend(select_edus.into_iter().map(SendingEvent::Edu)); + let select_edus = select_edus + .into_iter() + .map(Into::into) + .map(SendingEvent::Edu); + + events.extend(select_edus); self.db.set_latest_educount(server_name, last_count); } } @@ -357,7 +364,7 @@ impl Service { level = "debug", skip_all, )] - async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + async fn select_edus(&self, server_name: &ServerName) -> Result<(EduVec, u64)> { // selection window let since = self.db.get_latest_educount(server_name).await; let since_upper = self.services.globals.current_count()?; @@ -405,8 +412,8 @@ impl Service { since: (u64, u64), max_edu_count: &AtomicU64, events_len: &AtomicUsize, - ) -> Vec> { - let mut events = Vec::new(); + ) -> EduVec { + let mut events = EduVec::new(); let server_rooms = self.services.state_cache.server_rooms(server_name); pin_mut!(server_rooms); @@ -441,10 +448,11 @@ impl Service { keys: None, }); - let edu = serde_json::to_vec(&edu) + let mut buf = EduBuf::new(); + serde_json::to_writer(&mut buf, &edu) .expect("failed to serialize device list update to JSON"); - events.push(edu); + events.push(buf); if events_len.fetch_add(1, Ordering::Relaxed) >= SELECT_EDU_LIMIT - 1 { return events; } @@ -465,7 +473,7 @@ impl Service { server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64, - ) -> Option> { + ) -> Option { let server_rooms = self.services.state_cache.server_rooms(server_name); pin_mut!(server_rooms); @@ -487,10 +495,11 @@ impl Service { let receipt_content = Edu::Receipt(ReceiptContent { receipts }); - let receipt_content = serde_json::to_vec(&receipt_content) + let mut buf = EduBuf::new(); + serde_json::to_writer(&mut buf, &receipt_content) .expect("Failed to serialize Receipt EDU to JSON vec"); - Some(receipt_content) + Some(buf) } /// Look for read receipts in this room @@ -569,7 +578,7 @@ impl Service { server_name: &ServerName, since: (u64, u64), max_edu_count: &AtomicU64, - ) -> Option> { + ) -> Option { let presence_since = self.services.presence.presence_since(since.0); pin_mut!(presence_since); @@ -628,10 +637,11 @@ impl Service { push: presence_updates.into_values().collect(), }); - let presence_content = serde_json::to_vec(&presence_content) + let mut buf = EduBuf::new(); + serde_json::to_writer(&mut buf, &presence_content) .expect("failed to serialize Presence EDU to JSON"); - Some(presence_content) + Some(buf) } fn send_events(&self, dest: Destination, events: Vec) -> SendingFuture<'_> {