diff --git a/conduwuit-example.toml b/conduwuit-example.toml
index a82d8f69..3669961a 100644
--- a/conduwuit-example.toml
+++ b/conduwuit-example.toml
@@ -1410,6 +1410,13 @@
 #
 #db_pool_queue_size = 256
 
+# Number of sender task workers; determines sender parallelism. Default is
+# '0', which means the value is determined internally, typically matching
+# the number of tokio worker threads or CPU cores. Override by setting a
+# non-zero value.
+#
+#sender_workers = 0
+
 [global.tls]
 
 # Path to a valid TLS certificate file.
diff --git a/src/admin/query/sending.rs b/src/admin/query/sending.rs
index 696067b7..3edbbe87 100644
--- a/src/admin/query/sending.rs
+++ b/src/admin/query/sending.rs
@@ -113,7 +113,7 @@ pub(super) async fn process(
 		| (None, Some(server_name), None, None) => services
 			.sending
 			.db
-			.queued_requests(&Destination::Normal(server_name.into())),
+			.queued_requests(&Destination::Federation(server_name.into())),
 		| (None, None, Some(user_id), Some(push_key)) => {
 			if push_key.is_empty() {
 				return Ok(RoomMessageEventContent::text_plain(
@@ -183,7 +183,7 @@
 		| (None, Some(server_name), None, None) => services
 			.sending
 			.db
-			.active_requests_for(&Destination::Normal(server_name.into())),
+			.active_requests_for(&Destination::Federation(server_name.into())),
 		| (None, None, Some(user_id), Some(push_key)) => {
 			if push_key.is_empty() {
 				return Ok(RoomMessageEventContent::text_plain(
diff --git a/src/core/config/mod.rs b/src/core/config/mod.rs
index 8fd5621f..23feb0ca 100644
--- a/src/core/config/mod.rs
+++ b/src/core/config/mod.rs
@@ -1598,6 +1598,15 @@ pub struct Config {
 	#[serde(default = "default_db_pool_queue_size")]
 	pub db_pool_queue_size: usize,
 
+	/// Number of sender task workers; determines sender parallelism. Default
+	/// is '0', which means the value is determined internally, typically
+	/// matching the number of tokio worker threads or CPU cores. Override by
+	/// setting a non-zero value.
+	///
+	/// default: 0
+	#[serde(default)]
+	pub sender_workers: usize,
+
 	#[serde(flatten)]
 	#[allow(clippy::zero_sized_map_values)] // this is a catchall, the map shouldn't be zero at runtime
diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs
index 63de6539..3f609b95 100644
--- a/src/service/resolver/actual.rs
+++ b/src/service/resolver/actual.rs
@@ -4,6 +4,7 @@ use std::{
 };
 
 use conduwuit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result};
+use futures::FutureExt;
 use hickory_resolver::error::ResolveError;
 use ipaddress::IPAddress;
 use ruma::ServerName;
@@ -32,7 +33,7 @@ impl super::Service {
 			(result, true)
 		} else {
 			self.validate_dest(server_name)?;
-			(self.resolve_actual_dest(server_name, true).await?, false)
+			(self.resolve_actual_dest(server_name, true).boxed().await?, false)
 		};
 
 		let CachedDest { dest, host, .. } = result;
diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs
index ac06424f..a699b8ee 100644
--- a/src/service/sending/data.rs
+++ b/src/service/sending/data.rs
@@ -246,7 +246,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
 		})?;
 
 		(
-			Destination::Normal(ServerName::parse(server).map_err(|_| {
+			Destination::Federation(ServerName::parse(server).map_err(|_| {
 				Error::bad_database("Invalid server string in server_currenttransaction")
 			})?),
 			if value.is_empty() {
diff --git a/src/service/sending/dest.rs b/src/service/sending/dest.rs
index 2c6063cc..4099d372 100644
--- a/src/service/sending/dest.rs
+++ b/src/service/sending/dest.rs
@@ -7,14 +7,14 @@ use ruma::{OwnedServerName, OwnedUserId};
 pub enum Destination {
 	Appservice(String),
 	Push(OwnedUserId, String), // user and pushkey
-	Normal(OwnedServerName),
+	Federation(OwnedServerName),
 }
 
 #[implement(Destination)]
 #[must_use]
 pub(super) fn get_prefix(&self) -> Vec<u8> {
 	match self {
-		| Self::Normal(server) => {
+		| Self::Federation(server) => {
 			let len = server.as_bytes().len().saturating_add(1);
 
 			let mut p = Vec::with_capacity(len);
diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs
index 2038f4eb..5ccba249 100644
--- a/src/service/sending/mod.rs
+++ b/src/service/sending/mod.rs
@@ -4,20 +4,25 @@ mod dest;
 mod send;
 mod sender;
 
-use std::{fmt::Debug, iter::once, sync::Arc};
+use std::{
+	fmt::Debug,
+	hash::{DefaultHasher, Hash, Hasher},
+	iter::once,
+	sync::Arc,
+};
 
 use async_trait::async_trait;
 use conduwuit::{
-	debug_warn, err,
-	utils::{ReadyExt, TryReadyExt},
+	debug, debug_warn, err, error,
+	utils::{available_parallelism, math::usize_from_u64_truncated, ReadyExt, TryReadyExt},
 	warn, Result, Server,
 };
-use futures::{Stream, StreamExt};
+use futures::{FutureExt, Stream, StreamExt};
 use ruma::{
 	api::{appservice::Registration, OutgoingRequest},
 	RoomId, ServerName, UserId,
 };
-use tokio::sync::Mutex;
+use tokio::task::JoinSet;
 
 use self::data::Data;
 pub use self::{
@@ -30,11 +35,10 @@ use crate::{
 };
 
 pub struct Service {
+	pub db: Data,
 	server: Arc<Server>,
 	services: Services,
-	pub db: Data,
-	sender: loole::Sender<Msg>,
-	receiver: Mutex<loole::Receiver<Msg>>,
+	channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>,
 }
 
 struct Services {
@@ -72,8 +76,9 @@ pub enum SendingEvent {
 #[async_trait]
 impl crate::Service for Service {
 	fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
-		let (sender, receiver) = loole::unbounded();
+		let num_senders = num_senders(&args);
 		Ok(Arc::new(Self {
+			db: Data::new(&args),
 			server: args.server.clone(),
 			services: Services {
 				client: args.depend::<client::Service>("client"),
@@ -91,20 +96,41 @@ impl crate::Service for Service {
 				pusher: args.depend::<pusher::Service>("pusher"),
 				server_keys: args.depend::<server_keys::Service>("server_keys"),
 			},
-			db: Data::new(&args),
-			sender,
-			receiver: Mutex::new(receiver),
+			channels: (0..num_senders).map(|_| loole::unbounded()).collect(),
 		}))
 	}
 
-	async fn worker(self: Arc<Self>) -> Result<()> {
-		// trait impl can't be split between files so this just glues to mod sender
-		self.sender().await
+	async fn worker(self: Arc<Self>) -> Result {
+		let mut senders =
+			self.channels
+				.iter()
+				.enumerate()
+				.fold(JoinSet::new(), |mut joinset, (id, _)| {
+					let self_ = self.clone();
+					let runtime = self.server.runtime();
+					let _abort = joinset.spawn_on(self_.sender(id).boxed(), runtime);
+					joinset
+				});
+
+		while let Some(ret) = senders.join_next_with_id().await {
+			match ret {
+				| Ok((id, _)) => {
+					debug!(?id, "sender worker finished");
+				},
+				| Err(error) => {
+					error!(id = ?error.id(), ?error, "sender worker finished");
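+					// A JoinError from join_next_with_id() means the sender task
+					// panicked or was aborted; error.id() names the task that ended.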
+				},
+			};
+		}
+
+		Ok(())
 	}
 
 	fn interrupt(&self) {
-		if !self.sender.is_closed() {
-			self.sender.close();
+		for (sender, _) in &self.channels {
+			if !sender.is_closed() {
+				sender.close();
+			}
 		}
 	}
@@ -157,7 +183,7 @@ impl Service {
 		let _cork = self.db.db.cork();
 		let requests = servers
 			.map(|server| {
-				(Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
+				(Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
 			})
 			.collect::<Vec<_>>()
 			.await;
@@ -173,7 +199,7 @@
 
 	#[tracing::instrument(skip(self, server, serialized), level = "debug")]
 	pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> {
-		let dest = Destination::Normal(server.to_owned());
+		let dest = Destination::Federation(server.to_owned());
 		let event = SendingEvent::Edu(serialized);
 		let _cork = self.db.db.cork();
 		let keys = self.db.queue_requests(once((&event, &dest)));
@@ -203,7 +229,10 @@
 		let _cork = self.db.db.cork();
 		let requests = servers
 			.map(|server| {
-				(Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone()))
+				(
+					Destination::Federation(server.to_owned()),
+					SendingEvent::Edu(serialized.clone()),
+				)
 			})
 			.collect::<Vec<_>>()
 			.await;
@@ -235,7 +264,7 @@
 	{
 		servers
 			.map(ToOwned::to_owned)
-			.map(Destination::Normal)
+			.map(Destination::Federation)
 			.map(Ok)
 			.ready_try_for_each(|dest| {
 				self.dispatch(Msg {
@@ -327,9 +356,49 @@
 		}
 	}
 
-	fn dispatch(&self, msg: Msg) -> Result<()> {
-		debug_assert!(!self.sender.is_full(), "channel full");
-		debug_assert!(!self.sender.is_closed(), "channel closed");
-		self.sender.send(msg).map_err(|e| err!("{e}"))
+	fn dispatch(&self, msg: Msg) -> Result {
+		let shard = self.shard_id(&msg.dest);
+		let sender = &self
+			.channels
+			.get(shard)
+			.expect("missing sender worker channels")
+			.0;
+
+		debug_assert!(!sender.is_full(), "channel full");
+		debug_assert!(!sender.is_closed(), "channel closed");
+		sender.send(msg).map_err(|e| err!("{e}"))
+	}
+
+	pub(super) fn shard_id(&self, dest: &Destination) -> usize {
+		if self.channels.len() <= 1 {
+			return 0;
+		}
+
+		let mut hash = DefaultHasher::default();
+		dest.hash(&mut hash);
+
+		let hash: u64 = hash.finish();
+		let hash = usize_from_u64_truncated(hash);
+
+		let chans = self.channels.len().max(1);
+		hash.overflowing_rem(chans).0
	}
 }
+
+fn num_senders(args: &crate::Args<'_>) -> usize {
+	const MIN_SENDERS: usize = 1;
+	// Limit the number of senders to the number of tokio worker threads or the
+	// number of cores, conservatively.
+	let max_senders = args
+		.server
+		.metrics
+		.num_workers()
+		.min(available_parallelism());
+
+	// If the user doesn't override the default of 0, this currently resolves to
+	// 1, as multiple senders are still experimental.
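+	// Worked example (assuming max_senders resolves to 8): the default
+	// sender_workers = 0 clamps up to MIN_SENDERS = 1; a configured 4 passes
+	// through unchanged; 64 would be capped down to 8.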
+	args.server
+		.config
+		.sender_workers
+		.clamp(MIN_SENDERS, max_senders)
+}
diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs
index 81467c16..3d13a3b0 100644
--- a/src/service/sending/send.rs
+++ b/src/service/sending/send.rs
@@ -24,7 +24,10 @@ use crate::{
 };
 
 impl super::Service {
-	#[tracing::instrument(skip_all, level = "debug")]
+	#[tracing::instrument(
+		level = "debug",
+		skip(self, client, request),
+	)]
 	pub async fn send(
 		&self,
 		client: &Client,
diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs
index f6b83e83..4e806ce8 100644
--- a/src/service/sending/sender.rs
+++ b/src/service/sending/sender.rs
@@ -1,7 +1,10 @@
 use std::{
 	collections::{BTreeMap, HashMap, HashSet},
 	fmt::Debug,
-	sync::atomic::{AtomicU64, AtomicUsize, Ordering},
+	sync::{
+		atomic::{AtomicU64, AtomicUsize, Ordering},
+		Arc,
+	},
 	time::{Duration, Instant},
 };
 
@@ -66,29 +69,56 @@ pub const PDU_LIMIT: usize = 50;
 pub const EDU_LIMIT: usize = 100;
 
 impl Service {
-	#[tracing::instrument(skip_all, level = "debug")]
-	pub(super) async fn sender(&self) -> Result<()> {
+	#[tracing::instrument(skip(self), level = "debug")]
+	pub(super) async fn sender(self: Arc<Self>, id: usize) -> Result {
 		let mut statuses: CurTransactionStatus = CurTransactionStatus::new();
 		let mut futures: SendingFutures<'_> = FuturesUnordered::new();
-		let receiver = self.receiver.lock().await;
 
-		self.initial_requests(&mut futures, &mut statuses).await;
-		while !receiver.is_closed() {
-			tokio::select! {
-				request = receiver.recv_async() => match request {
-					Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await,
-					Err(_) => break,
-				},
-				Some(response) = futures.next() => {
-					self.handle_response(response, &mut futures, &mut statuses).await;
-				},
-			}
-		}
-		self.finish_responses(&mut futures).await;
+		self.startup_netburst(id, &mut futures, &mut statuses)
+			.boxed()
+			.await;
+
+		self.work_loop(id, &mut futures, &mut statuses).await;
+
+		self.finish_responses(&mut futures).boxed().await;
 
 		Ok(())
 	}
 
+	#[tracing::instrument(
+		name = "work",
+		level = "trace",
+		skip_all,
+		fields(
+			futures = %futures.len(),
+			statuses = %statuses.len(),
+		),
+	)]
+	async fn work_loop<'a>(
+		&'a self,
+		id: usize,
+		futures: &mut SendingFutures<'a>,
+		statuses: &mut CurTransactionStatus,
+	) {
+		let receiver = self
+			.channels
+			.get(id)
+			.map(|(_, receiver)| receiver.clone())
+			.expect("Missing channel for sender worker");
+
+		loop {
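+			// tokio::select! without `biased;` polls its branches in random
+			// order, so response handling and channel intake share the loop
+			// fairly.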
+			tokio::select! {
+				Some(response) = futures.next() => {
+					self.handle_response(response, futures, statuses).await;
+				},
+				request = receiver.recv_async() => match request {
+					Ok(request) => self.handle_request(request, futures, statuses).await,
+					Err(_) => return,
+				},
+			}
+		}
+	}
+
+	#[tracing::instrument(name = "response", level = "debug", skip_all)]
 	async fn handle_response<'a>(
 		&'a self,
 		response: SendingResult,
@@ -138,13 +168,14 @@
 			self.db.mark_as_active(new_events.iter());
 
 			let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect();
-			futures.push(self.send_events(dest.clone(), new_events_vec).boxed());
+			futures.push(self.send_events(dest.clone(), new_events_vec));
 		} else {
 			statuses.remove(dest);
 		}
 	}
 
 	#[allow(clippy::needless_pass_by_ref_mut)]
+	#[tracing::instrument(name = "request", level = "debug", skip_all)]
 	async fn handle_request<'a>(
 		&'a self,
 		msg: Msg,
@@ -154,13 +185,19 @@
 		let iv = vec![(msg.queue_id, msg.event)];
 		if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
 			if !events.is_empty() {
-				futures.push(self.send_events(msg.dest, events).boxed());
+				futures.push(self.send_events(msg.dest, events));
 			} else {
 				statuses.remove(&msg.dest);
 			}
 		}
 	}
 
+	#[tracing::instrument(
+		name = "finish",
+		level = "info",
+		skip_all,
+		fields(futures = %futures.len()),
+	)]
 	async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>) {
 		use tokio::{
 			select,
@@ -183,9 +220,16 @@
 		}
 	}
 
+	#[tracing::instrument(
+		name = "netburst",
+		level = "debug",
+		skip_all,
+		fields(futures = %futures.len()),
+	)]
 	#[allow(clippy::needless_pass_by_ref_mut)]
-	async fn initial_requests<'a>(
+	async fn startup_netburst<'a>(
 		&'a self,
+		id: usize,
 		futures: &mut SendingFutures<'a>,
 		statuses: &mut CurTransactionStatus,
 	) {
@@ -195,6 +239,10 @@
 		let mut active = self.db.active_requests().boxed();
 
 		while let Some((key, event, dest)) = active.next().await {
+			if self.shard_id(&dest) != id {
+				continue;
+			}
+
 			let entry = txns.entry(dest.clone()).or_default();
 			if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep {
 				warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key));
@@ -207,19 +255,27 @@
 		for (dest, events) in txns {
 			if self.server.config.startup_netburst && !events.is_empty() {
 				statuses.insert(dest.clone(), TransactionStatus::Running);
-				futures.push(self.send_events(dest.clone(), events).boxed());
+				futures.push(self.send_events(dest.clone(), events));
 			}
 		}
 	}
 
-	#[tracing::instrument(skip_all, level = "debug")]
+	#[tracing::instrument(
+		name = "select",
+		level = "debug",
+		skip_all,
+		fields(
+			?dest,
+			new_events = %new_events.len(),
+		)
+	)]
 	async fn select_events(
 		&self,
 		dest: &Destination,
 		new_events: Vec<QueueItem>, // Events we want to send: event and full key
 		statuses: &mut CurTransactionStatus,
 	) -> Result<Option<Vec<SendingEvent>>> {
-		let (allow, retry) = self.select_events_current(dest.clone(), statuses)?;
+		let (allow, retry) = self.select_events_current(dest, statuses)?;
 
 		// Nothing can be done for this remote, bail out.
 		if !allow {
@@ -249,7 +305,7 @@
 		}
 
 		// Add EDU's into the transaction
-		if let Destination::Normal(server_name) = dest {
+		if let Destination::Federation(server_name) = dest {
 			if let Ok((select_edus, last_count)) = self.select_edus(server_name).await {
 				debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit");
 				events.extend(select_edus.into_iter().map(SendingEvent::Edu));
@@ -260,10 +316,9 @@
 		Ok(Some(events))
 	}
 
-	#[tracing::instrument(skip_all, level = "debug")]
 	fn select_events_current(
 		&self,
-		dest: Destination,
+		dest: &Destination,
 		statuses: &mut CurTransactionStatus,
 	) -> Result<(bool, bool)> {
 		let (mut allow, mut retry) = (true, false);
@@ -292,7 +347,11 @@
 		Ok((allow, retry))
 	}
 
-	#[tracing::instrument(skip_all, level = "debug")]
+	#[tracing::instrument(
+		name = "edus",
+		level = "debug",
+		skip_all,
+	)]
 	async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
 		// selection window
 		let since = self.db.get_latest_educount(server_name).await;
@@ -329,7 +388,12 @@
 		Ok((events, max_edu_count.load(Ordering::Acquire)))
 	}
 
-	/// Look for presence
+	/// Look for device changes
+	#[tracing::instrument(
+		name = "device_changes",
+		level = "trace",
+		skip(self, server_name, max_edu_count)
+	)]
 	async fn select_edus_device_changes(
 		&self,
 		server_name: &ServerName,
@@ -386,6 +450,11 @@
 	}
 
 	/// Look for read receipts in this room
+	#[tracing::instrument(
+		name = "receipts",
+		level = "trace",
+		skip(self, server_name, max_edu_count)
+	)]
 	async fn select_edus_receipts(
 		&self,
 		server_name: &ServerName,
@@ -420,6 +489,7 @@
 	}
 
 	/// Look for read receipts in this room
+	#[tracing::instrument(name = "receipts", level = "trace", skip(self, since, max_edu_count))]
 	async fn select_edus_receipts_room(
 		&self,
 		room_id: &RoomId,
@@ -484,6 +554,11 @@
 	}
 
 	/// Look for presence
+	#[tracing::instrument(
+		name = "presence",
+		level = "trace",
+		skip(self, server_name, max_edu_count)
+	)]
 	async fn select_edus_presence(
 		&self,
 		server_name: &ServerName,
@@ -554,29 +629,33 @@
 		Some(presence_content)
 	}
 
-	async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult {
+	fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingFuture<'_> {
 		//debug_assert!(!events.is_empty(), "sending empty transaction");
 		match dest {
-			| Destination::Normal(ref server) =>
-				self.send_events_dest_normal(&dest, server, events).await,
-			| Destination::Appservice(ref id) =>
-				self.send_events_dest_appservice(&dest, id, events).await,
-			| Destination::Push(ref userid, ref pushkey) =>
-				self.send_events_dest_push(&dest, userid, pushkey, events)
-					.await,
+			| Destination::Federation(server) =>
+				self.send_events_dest_federation(server, events).boxed(),
+			| Destination::Appservice(id) => self.send_events_dest_appservice(id, events).boxed(),
+			| Destination::Push(user_id, pushkey) =>
+				self.send_events_dest_push(user_id, pushkey, events).boxed(),
 		}
 	}
 
-	#[tracing::instrument(skip(self, dest, events), name = "appservice")]
+	#[tracing::instrument(
+		name = "appservice",
+		level = "debug",
+		skip(self, events),
+		fields(
+			events = %events.len(),
+		),
+	)]
 	async fn send_events_dest_appservice(
 		&self,
-		dest: &Destination,
-		id: &str,
+		id: String,
 		events: Vec<SendingEvent>,
 	) -> SendingResult {
-		let Some(appservice) = self.services.appservice.get_registration(id).await else {
+		let Some(appservice) = self.services.appservice.get_registration(&id).await else {
 			return Err((
-				dest.clone(),
+				Destination::Appservice(id.clone()),
 				err!(Database(warn!(?id, "Missing appservice registration"))),
 			));
 		};
@@ -633,23 +712,29 @@
 			)
 			.await
 		{
-			| Ok(_) => Ok(dest.clone()),
-			| Err(e) => Err((dest.clone(), e)),
+			| Ok(_) => Ok(Destination::Appservice(id)),
+			| Err(e) => Err((Destination::Appservice(id), e)),
 		}
 	}
 
-	#[tracing::instrument(skip(self, dest, events), name = "push")]
+	#[tracing::instrument(
+		name = "push",
+		level = "info",
+		skip(self, events),
+		fields(
+			events = %events.len(),
+		),
+	)]
 	async fn send_events_dest_push(
 		&self,
-		dest: &Destination,
-		userid: &OwnedUserId,
-		pushkey: &str,
+		user_id: OwnedUserId,
+		pushkey: String,
 		events: Vec<SendingEvent>,
 	) -> SendingResult {
-		let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else {
+		let Ok(pusher) = self.services.pusher.get_pusher(&user_id, &pushkey).await else {
 			return Err((
-				dest.clone(),
-				err!(Database(error!(?userid, ?pushkey, "Missing pusher"))),
+				Destination::Push(user_id.clone(), pushkey.clone()),
+				err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))),
 			));
 		};
@@ -677,17 +762,17 @@
 		let rules_for_user = self
 			.services
 			.account_data
-			.get_global(userid, GlobalAccountDataEventType::PushRules)
+			.get_global(&user_id, GlobalAccountDataEventType::PushRules)
 			.await
 			.map_or_else(
-				|_| push::Ruleset::server_default(userid),
+				|_| push::Ruleset::server_default(&user_id),
 				|ev: PushRulesEvent| ev.content.global,
 			);
 
 		let unread: UInt = self
 			.services
 			.user
-			.notification_count(userid, &pdu.room_id)
+			.notification_count(&user_id, &pdu.room_id)
 			.await
 			.try_into()
 			.expect("notification count can't go that high");
@@ -695,19 +780,25 @@
 			let _response = self
 				.services
 				.pusher
-				.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu)
+				.send_push_notice(&user_id, unread, &pusher, rules_for_user, &pdu)
 				.await
-				.map_err(|e| (dest.clone(), e));
+				.map_err(|e| (Destination::Push(user_id.clone(), pushkey.clone()), e));
 		}
 
-		Ok(dest.clone())
+		Ok(Destination::Push(user_id, pushkey))
 	}
 
-	#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")]
-	async fn send_events_dest_normal(
+	#[tracing::instrument(
+		name = "fed",
+		level = "debug",
+		skip(self, events),
+		fields(
+			events = %events.len(),
+		),
+	)]
+	async fn send_events_dest_federation(
 		&self,
-		dest: &Destination,
-		server: &OwnedServerName,
+		server: OwnedServerName,
 		events: Vec<SendingEvent>,
 	) -> SendingResult {
 		let mut pdu_jsons = Vec::with_capacity(
@@ -759,7 +850,7 @@
 		};
 
 		let client = &self.services.client.sender;
-		self.send(client, server, request)
+		self.send(client, &server, request)
 			.await
 			.inspect(|response| {
 				response
@@ -770,8 +861,8 @@
 					|(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"),
 				);
 			})
-			.map(|_| dest.clone())
-			.map_err(|e| (dest.clone(), e))
+			.map_err(|e| (Destination::Federation(server.clone()), e))
+			.map(|_| Destination::Federation(server))
 	}
 
 	/// This does not return a full `Pdu` it is only to satisfy ruma's types.