shard sender into multiple task workers by destination hash

rename Destination::Normal variant

tracing instruments

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-18 03:34:56 +00:00 committed by strawberry
parent 98e6c81e49
commit af3d6a2e37
9 changed files with 275 additions and 95 deletions

View file

@ -1410,6 +1410,13 @@
# #
#db_pool_queue_size = 256 #db_pool_queue_size = 256
# Number of sender task workers; determines sender parallelism. Default is
# '0' which means the value is determined internally, likely matching the
# number of tokio worker-threads or number of cores, etc. Override by
# setting a non-zero value.
#
#sender_workers = 0
[global.tls] [global.tls]
# Path to a valid TLS certificate file. # Path to a valid TLS certificate file.

View file

@ -113,7 +113,7 @@ pub(super) async fn process(
| (None, Some(server_name), None, None) => services | (None, Some(server_name), None, None) => services
.sending .sending
.db .db
.queued_requests(&Destination::Normal(server_name.into())), .queued_requests(&Destination::Federation(server_name.into())),
| (None, None, Some(user_id), Some(push_key)) => { | (None, None, Some(user_id), Some(push_key)) => {
if push_key.is_empty() { if push_key.is_empty() {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
@ -183,7 +183,7 @@ pub(super) async fn process(
| (None, Some(server_name), None, None) => services | (None, Some(server_name), None, None) => services
.sending .sending
.db .db
.active_requests_for(&Destination::Normal(server_name.into())), .active_requests_for(&Destination::Federation(server_name.into())),
| (None, None, Some(user_id), Some(push_key)) => { | (None, None, Some(user_id), Some(push_key)) => {
if push_key.is_empty() { if push_key.is_empty() {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(

View file

@ -1598,6 +1598,15 @@ pub struct Config {
#[serde(default = "default_db_pool_queue_size")] #[serde(default = "default_db_pool_queue_size")]
pub db_pool_queue_size: usize, pub db_pool_queue_size: usize,
/// Number of sender task workers; determines sender parallelism. Default is
/// '0' which means the value is determined internally, likely matching the
/// number of tokio worker-threads or number of cores, etc. Override by
/// setting a non-zero value.
///
/// default: 0
#[serde(default)]
pub sender_workers: usize,
#[serde(flatten)] #[serde(flatten)]
#[allow(clippy::zero_sized_map_values)] #[allow(clippy::zero_sized_map_values)]
// this is a catchall, the map shouldn't be zero at runtime // this is a catchall, the map shouldn't be zero at runtime

View file

@ -4,6 +4,7 @@ use std::{
}; };
use conduwuit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result}; use conduwuit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result};
use futures::FutureExt;
use hickory_resolver::error::ResolveError; use hickory_resolver::error::ResolveError;
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::ServerName; use ruma::ServerName;
@ -32,7 +33,7 @@ impl super::Service {
(result, true) (result, true)
} else { } else {
self.validate_dest(server_name)?; self.validate_dest(server_name)?;
(self.resolve_actual_dest(server_name, true).await?, false) (self.resolve_actual_dest(server_name, true).boxed().await?, false)
}; };
let CachedDest { dest, host, .. } = result; let CachedDest { dest, host, .. } = result;

View file

@ -246,7 +246,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
})?; })?;
( (
Destination::Normal(ServerName::parse(server).map_err(|_| { Destination::Federation(ServerName::parse(server).map_err(|_| {
Error::bad_database("Invalid server string in server_currenttransaction") Error::bad_database("Invalid server string in server_currenttransaction")
})?), })?),
if value.is_empty() { if value.is_empty() {

View file

@ -7,14 +7,14 @@ use ruma::{OwnedServerName, OwnedUserId};
pub enum Destination { pub enum Destination {
Appservice(String), Appservice(String),
Push(OwnedUserId, String), // user and pushkey Push(OwnedUserId, String), // user and pushkey
Normal(OwnedServerName), Federation(OwnedServerName),
} }
#[implement(Destination)] #[implement(Destination)]
#[must_use] #[must_use]
pub(super) fn get_prefix(&self) -> Vec<u8> { pub(super) fn get_prefix(&self) -> Vec<u8> {
match self { match self {
| Self::Normal(server) => { | Self::Federation(server) => {
let len = server.as_bytes().len().saturating_add(1); let len = server.as_bytes().len().saturating_add(1);
let mut p = Vec::with_capacity(len); let mut p = Vec::with_capacity(len);

View file

@ -4,20 +4,25 @@ mod dest;
mod send; mod send;
mod sender; mod sender;
use std::{fmt::Debug, iter::once, sync::Arc}; use std::{
fmt::Debug,
hash::{DefaultHasher, Hash, Hasher},
iter::once,
sync::Arc,
};
use async_trait::async_trait; use async_trait::async_trait;
use conduwuit::{ use conduwuit::{
debug_warn, err, debug, debug_warn, err, error,
utils::{ReadyExt, TryReadyExt}, utils::{available_parallelism, math::usize_from_u64_truncated, ReadyExt, TryReadyExt},
warn, Result, Server, warn, Result, Server,
}; };
use futures::{Stream, StreamExt}; use futures::{FutureExt, Stream, StreamExt};
use ruma::{ use ruma::{
api::{appservice::Registration, OutgoingRequest}, api::{appservice::Registration, OutgoingRequest},
RoomId, ServerName, UserId, RoomId, ServerName, UserId,
}; };
use tokio::sync::Mutex; use tokio::task::JoinSet;
use self::data::Data; use self::data::Data;
pub use self::{ pub use self::{
@ -30,11 +35,10 @@ use crate::{
}; };
pub struct Service { pub struct Service {
pub db: Data,
server: Arc<Server>, server: Arc<Server>,
services: Services, services: Services,
pub db: Data, channels: Vec<(loole::Sender<Msg>, loole::Receiver<Msg>)>,
sender: loole::Sender<Msg>,
receiver: Mutex<loole::Receiver<Msg>>,
} }
struct Services { struct Services {
@ -72,8 +76,9 @@ pub enum SendingEvent {
#[async_trait] #[async_trait]
impl crate::Service for Service { impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> { fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
let (sender, receiver) = loole::unbounded(); let num_senders = num_senders(&args);
Ok(Arc::new(Self { Ok(Arc::new(Self {
db: Data::new(&args),
server: args.server.clone(), server: args.server.clone(),
services: Services { services: Services {
client: args.depend::<client::Service>("client"), client: args.depend::<client::Service>("client"),
@ -91,20 +96,41 @@ impl crate::Service for Service {
pusher: args.depend::<pusher::Service>("pusher"), pusher: args.depend::<pusher::Service>("pusher"),
server_keys: args.depend::<server_keys::Service>("server_keys"), server_keys: args.depend::<server_keys::Service>("server_keys"),
}, },
db: Data::new(&args), channels: (0..num_senders).map(|_| loole::unbounded()).collect(),
sender,
receiver: Mutex::new(receiver),
})) }))
} }
async fn worker(self: Arc<Self>) -> Result<()> { async fn worker(self: Arc<Self>) -> Result {
// trait impl can't be split between files so this just glues to mod sender let mut senders =
self.sender().await self.channels
.iter()
.enumerate()
.fold(JoinSet::new(), |mut joinset, (id, _)| {
let self_ = self.clone();
let runtime = self.server.runtime();
let _abort = joinset.spawn_on(self_.sender(id).boxed(), runtime);
joinset
});
while let Some(ret) = senders.join_next_with_id().await {
match ret {
| Ok((id, _)) => {
debug!(?id, "sender worker finished");
},
| Err(error) => {
error!(id = ?error.id(), ?error, "sender worker finished");
},
};
}
Ok(())
} }
fn interrupt(&self) { fn interrupt(&self) {
if !self.sender.is_closed() { for (sender, _) in &self.channels {
self.sender.close(); if !sender.is_closed() {
sender.close();
}
} }
} }
@ -157,7 +183,7 @@ impl Service {
let _cork = self.db.db.cork(); let _cork = self.db.db.cork();
let requests = servers let requests = servers
.map(|server| { .map(|server| {
(Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned())) (Destination::Federation(server.into()), SendingEvent::Pdu(pdu_id.to_owned()))
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.await; .await;
@ -173,7 +199,7 @@ impl Service {
#[tracing::instrument(skip(self, server, serialized), level = "debug")] #[tracing::instrument(skip(self, server, serialized), level = "debug")]
pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> { pub fn send_edu_server(&self, server: &ServerName, serialized: Vec<u8>) -> Result<()> {
let dest = Destination::Normal(server.to_owned()); let dest = Destination::Federation(server.to_owned());
let event = SendingEvent::Edu(serialized); let event = SendingEvent::Edu(serialized);
let _cork = self.db.db.cork(); let _cork = self.db.db.cork();
let keys = self.db.queue_requests(once((&event, &dest))); let keys = self.db.queue_requests(once((&event, &dest)));
@ -203,7 +229,10 @@ impl Service {
let _cork = self.db.db.cork(); let _cork = self.db.db.cork();
let requests = servers let requests = servers
.map(|server| { .map(|server| {
(Destination::Normal(server.to_owned()), SendingEvent::Edu(serialized.clone())) (
Destination::Federation(server.to_owned()),
SendingEvent::Edu(serialized.clone()),
)
}) })
.collect::<Vec<_>>() .collect::<Vec<_>>()
.await; .await;
@ -235,7 +264,7 @@ impl Service {
{ {
servers servers
.map(ToOwned::to_owned) .map(ToOwned::to_owned)
.map(Destination::Normal) .map(Destination::Federation)
.map(Ok) .map(Ok)
.ready_try_for_each(|dest| { .ready_try_for_each(|dest| {
self.dispatch(Msg { self.dispatch(Msg {
@ -327,9 +356,49 @@ impl Service {
} }
} }
fn dispatch(&self, msg: Msg) -> Result<()> { fn dispatch(&self, msg: Msg) -> Result {
debug_assert!(!self.sender.is_full(), "channel full"); let shard = self.shard_id(&msg.dest);
debug_assert!(!self.sender.is_closed(), "channel closed"); let sender = &self
self.sender.send(msg).map_err(|e| err!("{e}")) .channels
.get(shard)
.expect("missing sender worker channels")
.0;
debug_assert!(!sender.is_full(), "channel full");
debug_assert!(!sender.is_closed(), "channel closed");
sender.send(msg).map_err(|e| err!("{e}"))
}
pub(super) fn shard_id(&self, dest: &Destination) -> usize {
if self.channels.len() <= 1 {
return 0;
}
let mut hash = DefaultHasher::default();
dest.hash(&mut hash);
let hash: u64 = hash.finish();
let hash = usize_from_u64_truncated(hash);
let chans = self.channels.len().max(1);
hash.overflowing_rem(chans).0
} }
} }
fn num_senders(args: &crate::Args<'_>) -> usize {
const MIN_SENDERS: usize = 1;
// Limit the number of senders to the number of workers threads or number of
// cores, conservatively.
let max_senders = args
.server
.metrics
.num_workers()
.min(available_parallelism());
// If the user doesn't override the default 0, this is intended to then default
// to 1 for now as multiple senders is experimental.
args.server
.config
.sender_workers
.clamp(MIN_SENDERS, max_senders)
}

View file

@ -24,7 +24,10 @@ use crate::{
}; };
impl super::Service { impl super::Service {
#[tracing::instrument(skip_all, level = "debug")] #[tracing::instrument(
level = "debug"
skip(self, client, request),
)]
pub async fn send<T>( pub async fn send<T>(
&self, &self,
client: &Client, client: &Client,

View file

@ -1,7 +1,10 @@
use std::{ use std::{
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
fmt::Debug, fmt::Debug,
sync::atomic::{AtomicU64, AtomicUsize, Ordering}, sync::{
atomic::{AtomicU64, AtomicUsize, Ordering},
Arc,
},
time::{Duration, Instant}, time::{Duration, Instant},
}; };
@ -66,29 +69,56 @@ pub const PDU_LIMIT: usize = 50;
pub const EDU_LIMIT: usize = 100; pub const EDU_LIMIT: usize = 100;
impl Service { impl Service {
#[tracing::instrument(skip_all, level = "debug")] #[tracing::instrument(skip(self), level = "debug")]
pub(super) async fn sender(&self) -> Result<()> { pub(super) async fn sender(self: Arc<Self>, id: usize) -> Result {
let mut statuses: CurTransactionStatus = CurTransactionStatus::new(); let mut statuses: CurTransactionStatus = CurTransactionStatus::new();
let mut futures: SendingFutures<'_> = FuturesUnordered::new(); let mut futures: SendingFutures<'_> = FuturesUnordered::new();
let receiver = self.receiver.lock().await;
self.initial_requests(&mut futures, &mut statuses).await; self.startup_netburst(id, &mut futures, &mut statuses)
while !receiver.is_closed() { .boxed()
tokio::select! { .await;
request = receiver.recv_async() => match request {
Ok(request) => self.handle_request(request, &mut futures, &mut statuses).await, self.work_loop(id, &mut futures, &mut statuses).await;
Err(_) => break,
}, self.finish_responses(&mut futures).boxed().await;
Some(response) = futures.next() => {
self.handle_response(response, &mut futures, &mut statuses).await;
},
}
}
self.finish_responses(&mut futures).await;
Ok(()) Ok(())
} }
#[tracing::instrument(
name = "work",
level = "trace"
skip_all,
fields(
futures = %futures.len(),
statuses = %statuses.len(),
),
)]
async fn work_loop<'a>(
&'a self,
id: usize,
futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus,
) {
let receiver = self
.channels
.get(id)
.map(|(_, receiver)| receiver.clone())
.expect("Missing channel for sender worker");
loop {
tokio::select! {
Some(response) = futures.next() => {
self.handle_response(response, futures, statuses).await;
},
request = receiver.recv_async() => match request {
Ok(request) => self.handle_request(request, futures, statuses).await,
Err(_) => return,
},
}
}
}
#[tracing::instrument(name = "response", level = "debug", skip_all)]
async fn handle_response<'a>( async fn handle_response<'a>(
&'a self, &'a self,
response: SendingResult, response: SendingResult,
@ -138,13 +168,14 @@ impl Service {
self.db.mark_as_active(new_events.iter()); self.db.mark_as_active(new_events.iter());
let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect(); let new_events_vec = new_events.into_iter().map(|(_, event)| event).collect();
futures.push(self.send_events(dest.clone(), new_events_vec).boxed()); futures.push(self.send_events(dest.clone(), new_events_vec));
} else { } else {
statuses.remove(dest); statuses.remove(dest);
} }
} }
#[allow(clippy::needless_pass_by_ref_mut)] #[allow(clippy::needless_pass_by_ref_mut)]
#[tracing::instrument(name = "request", level = "debug", skip_all)]
async fn handle_request<'a>( async fn handle_request<'a>(
&'a self, &'a self,
msg: Msg, msg: Msg,
@ -154,13 +185,19 @@ impl Service {
let iv = vec![(msg.queue_id, msg.event)]; let iv = vec![(msg.queue_id, msg.event)];
if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await { if let Ok(Some(events)) = self.select_events(&msg.dest, iv, statuses).await {
if !events.is_empty() { if !events.is_empty() {
futures.push(self.send_events(msg.dest, events).boxed()); futures.push(self.send_events(msg.dest, events));
} else { } else {
statuses.remove(&msg.dest); statuses.remove(&msg.dest);
} }
} }
} }
#[tracing::instrument(
name = "finish",
level = "info",
skip_all,
fields(futures = %futures.len()),
)]
async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>) { async fn finish_responses<'a>(&'a self, futures: &mut SendingFutures<'a>) {
use tokio::{ use tokio::{
select, select,
@ -183,9 +220,16 @@ impl Service {
} }
} }
#[tracing::instrument(
name = "netburst",
level = "debug",
skip_all,
fields(futures = %futures.len()),
)]
#[allow(clippy::needless_pass_by_ref_mut)] #[allow(clippy::needless_pass_by_ref_mut)]
async fn initial_requests<'a>( async fn startup_netburst<'a>(
&'a self, &'a self,
id: usize,
futures: &mut SendingFutures<'a>, futures: &mut SendingFutures<'a>,
statuses: &mut CurTransactionStatus, statuses: &mut CurTransactionStatus,
) { ) {
@ -195,6 +239,10 @@ impl Service {
let mut active = self.db.active_requests().boxed(); let mut active = self.db.active_requests().boxed();
while let Some((key, event, dest)) = active.next().await { while let Some((key, event, dest)) = active.next().await {
if self.shard_id(&dest) != id {
continue;
}
let entry = txns.entry(dest.clone()).or_default(); let entry = txns.entry(dest.clone()).or_default();
if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep { if self.server.config.startup_netburst_keep >= 0 && entry.len() >= keep {
warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key)); warn!("Dropping unsent event {dest:?} {:?}", String::from_utf8_lossy(&key));
@ -207,19 +255,27 @@ impl Service {
for (dest, events) in txns { for (dest, events) in txns {
if self.server.config.startup_netburst && !events.is_empty() { if self.server.config.startup_netburst && !events.is_empty() {
statuses.insert(dest.clone(), TransactionStatus::Running); statuses.insert(dest.clone(), TransactionStatus::Running);
futures.push(self.send_events(dest.clone(), events).boxed()); futures.push(self.send_events(dest.clone(), events));
} }
} }
} }
#[tracing::instrument(skip_all, level = "debug")] #[tracing::instrument(
name = "select",,
level = "debug",
skip_all,
fields(
?dest,
new_events = %new_events.len(),
)
)]
async fn select_events( async fn select_events(
&self, &self,
dest: &Destination, dest: &Destination,
new_events: Vec<QueueItem>, // Events we want to send: event and full key new_events: Vec<QueueItem>, // Events we want to send: event and full key
statuses: &mut CurTransactionStatus, statuses: &mut CurTransactionStatus,
) -> Result<Option<Vec<SendingEvent>>> { ) -> Result<Option<Vec<SendingEvent>>> {
let (allow, retry) = self.select_events_current(dest.clone(), statuses)?; let (allow, retry) = self.select_events_current(dest, statuses)?;
// Nothing can be done for this remote, bail out. // Nothing can be done for this remote, bail out.
if !allow { if !allow {
@ -249,7 +305,7 @@ impl Service {
} }
// Add EDU's into the transaction // Add EDU's into the transaction
if let Destination::Normal(server_name) = dest { if let Destination::Federation(server_name) = dest {
if let Ok((select_edus, last_count)) = self.select_edus(server_name).await { if let Ok((select_edus, last_count)) = self.select_edus(server_name).await {
debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit"); debug_assert!(select_edus.len() <= EDU_LIMIT, "exceeded edus limit");
events.extend(select_edus.into_iter().map(SendingEvent::Edu)); events.extend(select_edus.into_iter().map(SendingEvent::Edu));
@ -260,10 +316,9 @@ impl Service {
Ok(Some(events)) Ok(Some(events))
} }
#[tracing::instrument(skip_all, level = "debug")]
fn select_events_current( fn select_events_current(
&self, &self,
dest: Destination, dest: &Destination,
statuses: &mut CurTransactionStatus, statuses: &mut CurTransactionStatus,
) -> Result<(bool, bool)> { ) -> Result<(bool, bool)> {
let (mut allow, mut retry) = (true, false); let (mut allow, mut retry) = (true, false);
@ -292,7 +347,11 @@ impl Service {
Ok((allow, retry)) Ok((allow, retry))
} }
#[tracing::instrument(skip_all, level = "debug")] #[tracing::instrument(
name = "edus",,
level = "debug",
skip_all,
)]
async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { async fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
// selection window // selection window
let since = self.db.get_latest_educount(server_name).await; let since = self.db.get_latest_educount(server_name).await;
@ -329,7 +388,12 @@ impl Service {
Ok((events, max_edu_count.load(Ordering::Acquire))) Ok((events, max_edu_count.load(Ordering::Acquire)))
} }
/// Look for presence /// Look for device changes
#[tracing::instrument(
name = "device_changes",
level = "trace",
skip(self, server_name, max_edu_count)
)]
async fn select_edus_device_changes( async fn select_edus_device_changes(
&self, &self,
server_name: &ServerName, server_name: &ServerName,
@ -386,6 +450,11 @@ impl Service {
} }
/// Look for read receipts in this room /// Look for read receipts in this room
#[tracing::instrument(
name = "receipts",
level = "trace",
skip(self, server_name, max_edu_count)
)]
async fn select_edus_receipts( async fn select_edus_receipts(
&self, &self,
server_name: &ServerName, server_name: &ServerName,
@ -420,6 +489,7 @@ impl Service {
} }
/// Look for read receipts in this room /// Look for read receipts in this room
#[tracing::instrument(name = "receipts", level = "trace", skip(self, since, max_edu_count))]
async fn select_edus_receipts_room( async fn select_edus_receipts_room(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
@ -484,6 +554,11 @@ impl Service {
} }
/// Look for presence /// Look for presence
#[tracing::instrument(
name = "presence",
level = "trace",
skip(self, server_name, max_edu_count)
)]
async fn select_edus_presence( async fn select_edus_presence(
&self, &self,
server_name: &ServerName, server_name: &ServerName,
@ -554,29 +629,33 @@ impl Service {
Some(presence_content) Some(presence_content)
} }
async fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingResult { fn send_events(&self, dest: Destination, events: Vec<SendingEvent>) -> SendingFuture<'_> {
//debug_assert!(!events.is_empty(), "sending empty transaction"); //debug_assert!(!events.is_empty(), "sending empty transaction");
match dest { match dest {
| Destination::Normal(ref server) => | Destination::Federation(server) =>
self.send_events_dest_normal(&dest, server, events).await, self.send_events_dest_federation(server, events).boxed(),
| Destination::Appservice(ref id) => | Destination::Appservice(id) => self.send_events_dest_appservice(id, events).boxed(),
self.send_events_dest_appservice(&dest, id, events).await, | Destination::Push(user_id, pushkey) =>
| Destination::Push(ref userid, ref pushkey) => self.send_events_dest_push(user_id, pushkey, events).boxed(),
self.send_events_dest_push(&dest, userid, pushkey, events)
.await,
} }
} }
#[tracing::instrument(skip(self, dest, events), name = "appservice")] #[tracing::instrument(
name = "appservice",
level = "debug",
skip(self, events),
fields(
events = %events.len(),
),
)]
async fn send_events_dest_appservice( async fn send_events_dest_appservice(
&self, &self,
dest: &Destination, id: String,
id: &str,
events: Vec<SendingEvent>, events: Vec<SendingEvent>,
) -> SendingResult { ) -> SendingResult {
let Some(appservice) = self.services.appservice.get_registration(id).await else { let Some(appservice) = self.services.appservice.get_registration(&id).await else {
return Err(( return Err((
dest.clone(), Destination::Appservice(id.clone()),
err!(Database(warn!(?id, "Missing appservice registration"))), err!(Database(warn!(?id, "Missing appservice registration"))),
)); ));
}; };
@ -633,23 +712,29 @@ impl Service {
) )
.await .await
{ {
| Ok(_) => Ok(dest.clone()), | Ok(_) => Ok(Destination::Appservice(id)),
| Err(e) => Err((dest.clone(), e)), | Err(e) => Err((Destination::Appservice(id), e)),
} }
} }
#[tracing::instrument(skip(self, dest, events), name = "push")] #[tracing::instrument(
name = "push",
level = "info",
skip(self, events),
fields(
events = %events.len(),
),
)]
async fn send_events_dest_push( async fn send_events_dest_push(
&self, &self,
dest: &Destination, user_id: OwnedUserId,
userid: &OwnedUserId, pushkey: String,
pushkey: &str,
events: Vec<SendingEvent>, events: Vec<SendingEvent>,
) -> SendingResult { ) -> SendingResult {
let Ok(pusher) = self.services.pusher.get_pusher(userid, pushkey).await else { let Ok(pusher) = self.services.pusher.get_pusher(&user_id, &pushkey).await else {
return Err(( return Err((
dest.clone(), Destination::Push(user_id.clone(), pushkey.clone()),
err!(Database(error!(?userid, ?pushkey, "Missing pusher"))), err!(Database(error!(?user_id, ?pushkey, "Missing pusher"))),
)); ));
}; };
@ -677,17 +762,17 @@ impl Service {
let rules_for_user = self let rules_for_user = self
.services .services
.account_data .account_data
.get_global(userid, GlobalAccountDataEventType::PushRules) .get_global(&user_id, GlobalAccountDataEventType::PushRules)
.await .await
.map_or_else( .map_or_else(
|_| push::Ruleset::server_default(userid), |_| push::Ruleset::server_default(&user_id),
|ev: PushRulesEvent| ev.content.global, |ev: PushRulesEvent| ev.content.global,
); );
let unread: UInt = self let unread: UInt = self
.services .services
.user .user
.notification_count(userid, &pdu.room_id) .notification_count(&user_id, &pdu.room_id)
.await .await
.try_into() .try_into()
.expect("notification count can't go that high"); .expect("notification count can't go that high");
@ -695,19 +780,25 @@ impl Service {
let _response = self let _response = self
.services .services
.pusher .pusher
.send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) .send_push_notice(&user_id, unread, &pusher, rules_for_user, &pdu)
.await .await
.map_err(|e| (dest.clone(), e)); .map_err(|e| (Destination::Push(user_id.clone(), pushkey.clone()), e));
} }
Ok(dest.clone()) Ok(Destination::Push(user_id, pushkey))
} }
#[tracing::instrument(skip(self, dest, events), name = "", level = "debug")] #[tracing::instrument(
async fn send_events_dest_normal( name = "fed",
level = "debug",
skip(self, events),
fields(
events = %events.len(),
),
)]
async fn send_events_dest_federation(
&self, &self,
dest: &Destination, server: OwnedServerName,
server: &OwnedServerName,
events: Vec<SendingEvent>, events: Vec<SendingEvent>,
) -> SendingResult { ) -> SendingResult {
let mut pdu_jsons = Vec::with_capacity( let mut pdu_jsons = Vec::with_capacity(
@ -759,7 +850,7 @@ impl Service {
}; };
let client = &self.services.client.sender; let client = &self.services.client.sender;
self.send(client, server, request) self.send(client, &server, request)
.await .await
.inspect(|response| { .inspect(|response| {
response response
@ -770,8 +861,8 @@ impl Service {
|(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"), |(pdu_id, res)| warn!(%txn_id, %server, "error sending PDU {pdu_id} to remote server: {res:?}"),
); );
}) })
.map(|_| dest.clone()) .map_err(|e| (Destination::Federation(server.clone()), e))
.map_err(|e| (dest.clone(), e)) .map(|_| Destination::Federation(server))
} }
/// This does not return a full `Pdu` it is only to satisfy ruma's types. /// This does not return a full `Pdu` it is only to satisfy ruma's types.