split out, dedup, cleanup sending service methods

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-03-30 02:10:08 -07:00 committed by June
parent 3c09313f79
commit a87e7d8e17

View file

@ -55,34 +55,6 @@ pub enum OutgoingKind {
Normal(OwnedServerName), Normal(OwnedServerName),
} }
impl OutgoingKind {
#[tracing::instrument(skip(self))]
pub fn get_prefix(&self) -> Vec<u8> {
let mut prefix = match self {
OutgoingKind::Appservice(server) => {
let mut p = b"+".to_vec();
p.extend_from_slice(server.as_bytes());
p
},
OutgoingKind::Push(user, pushkey) => {
let mut p = b"$".to_vec();
p.extend_from_slice(user.as_bytes());
p.push(0xFF);
p.extend_from_slice(pushkey.as_bytes());
p
},
OutgoingKind::Normal(server) => {
let mut p = Vec::new();
p.extend_from_slice(server.as_bytes());
p
},
};
prefix.push(0xFF);
prefix
}
}
#[derive(Clone, Debug, PartialEq, Eq, Hash)] #[derive(Clone, Debug, PartialEq, Eq, Hash)]
#[allow(clippy::module_name_repetitions)] #[allow(clippy::module_name_repetitions)]
pub enum SendingEventType { pub enum SendingEventType {
@ -106,6 +78,7 @@ impl Service {
receiver: Mutex::new(receiver), receiver: Mutex::new(receiver),
maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)),
startup_netburst: config.startup_netburst, startup_netburst: config.startup_netburst,
timeout: config.sender_timeout,
}) })
} }
@ -251,6 +224,41 @@ impl Service {
Ok(()) Ok(())
} }
#[tracing::instrument(skip(self, destination, request))]
pub async fn send_federation_request<T>(&self, destination: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
let permit = self.maximum_requests.acquire().await;
let timeout = Duration::from_secs(self.timeout);
let response = tokio::time::timeout(timeout, send::send_request(destination, request))
.await
.map_err(|_| {
warn!("Timeout after 300 seconds waiting for server response of {destination}");
Error::BadServerResponse("Timeout after 300 seconds waiting for server response")
})?;
drop(permit);
response
}
/// Sends a request to an appservice
///
/// Only returns None if there is no url specified in the appservice
/// registration file
pub async fn send_appservice_request<T>(
&self, registration: Registration, request: T,
) -> Option<Result<T::IncomingResponse>>
where
T: OutgoingRequest + Debug,
{
let permit = self.maximum_requests.acquire().await;
let response = appservice::send_request(registration, request).await;
drop(permit);
response
}
pub fn start_handler(self: &Arc<Self>) { pub fn start_handler(self: &Arc<Self>) {
let self2 = Arc::clone(self); let self2 = Arc::clone(self);
tokio::spawn(async move { tokio::spawn(async move {
@ -286,7 +294,7 @@ impl Service {
for (outgoing_kind, events) in initial_transactions { for (outgoing_kind, events) in initial_transactions {
current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running); current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running);
futures.push(Self::handle_events(outgoing_kind.clone(), events)); futures.push(handle_events(outgoing_kind.clone(), events));
} }
} }
@ -299,18 +307,19 @@ impl Service {
self.db.delete_all_active_requests_for(&outgoing_kind)?; self.db.delete_all_active_requests_for(&outgoing_kind)?;
// Find events that have been added since starting the last request // Find events that have been added since starting the last request
let new_events = self.db.queued_requests(&outgoing_kind).filter_map(Result::ok).take(30).collect::<Vec<_>>(); let new_events = self
.db
.queued_requests(&outgoing_kind)
.filter_map(Result::ok)
.take(30).collect::<Vec<_>>();
if !new_events.is_empty() { if !new_events.is_empty() {
// Insert pdus we found // Insert pdus we found
self.db.mark_as_active(&new_events)?; self.db.mark_as_active(&new_events)?;
futures.push(handle_events(
futures.push(
Self::handle_events(
outgoing_kind.clone(), outgoing_kind.clone(),
new_events.into_iter().map(|(event, _)| event).collect(), new_events.into_iter().map(|(event, _)| event).collect(),
) ));
);
} else { } else {
current_transaction_status.remove(&outgoing_kind); current_transaction_status.remove(&outgoing_kind);
} }
@ -333,7 +342,7 @@ impl Service {
vec![(event, key)], vec![(event, key)],
&mut current_transaction_status, &mut current_transaction_status,
) { ) {
futures.push(Self::handle_events(outgoing_kind, events)); futures.push(handle_events(outgoing_kind, events));
} }
} }
} }
@ -347,50 +356,28 @@ impl Service {
new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key new_events: Vec<(SendingEventType, Vec<u8>)>, // Events we want to send: event and full key
current_transaction_status: &mut HashMap<OutgoingKind, TransactionStatus>, current_transaction_status: &mut HashMap<OutgoingKind, TransactionStatus>,
) -> Result<Option<Vec<SendingEventType>>> { ) -> Result<Option<Vec<SendingEventType>>> {
let mut retry = false; let (allow, retry) = self.select_events_current(outgoing_kind.clone(), current_transaction_status)?;
let mut allow = true;
let _cork = services().globals.db.cork();
let entry = current_transaction_status.entry(outgoing_kind.clone());
entry
.and_modify(|e| match e {
TransactionStatus::Running | TransactionStatus::Retrying(_) => {
allow = false; // already running
},
TransactionStatus::Failed(tries, time) => {
// Fail if a request has failed recently (exponential backoff)
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
}
if time.elapsed() < min_elapsed_duration {
allow = false;
} else {
retry = true;
*e = TransactionStatus::Retrying(*tries);
}
},
})
.or_insert(TransactionStatus::Running);
// Nothing can be done for this remote, bail out.
if !allow { if !allow {
return Ok(None); return Ok(None);
} }
let _cork = services().globals.db.cork();
let mut events = Vec::new(); let mut events = Vec::new();
// Must retry any previous transaction for this remote.
if retry { if retry {
// We retry the previous transaction self.db
for (_, e) in self
.db
.active_requests_for(outgoing_kind) .active_requests_for(outgoing_kind)
.filter_map(Result::ok) .filter_map(Result::ok)
{ .for_each(|(_, e)| events.push(e));
events.push(e);
return Ok(Some(events));
} }
} else {
// Compose the next transaction
let _cork = services().globals.db.cork();
if !new_events.is_empty() { if !new_events.is_empty() {
self.db.mark_as_active(&new_events)?; self.db.mark_as_active(&new_events)?;
for (e, _) in new_events { for (e, _) in new_events {
@ -398,18 +385,46 @@ impl Service {
} }
} }
// Add EDU's into the transaction
if let OutgoingKind::Normal(server_name) = outgoing_kind { if let OutgoingKind::Normal(server_name) = outgoing_kind {
if let Ok((select_edus, last_count)) = self.select_edus(server_name) { if let Ok((select_edus, last_count)) = self.select_edus(server_name) {
events.extend(select_edus.into_iter().map(SendingEventType::Edu)); events.extend(select_edus.into_iter().map(SendingEventType::Edu));
self.db.set_latest_educount(server_name, last_count)?; self.db.set_latest_educount(server_name, last_count)?;
} }
} }
}
Ok(Some(events)) Ok(Some(events))
} }
#[tracing::instrument(skip(self, outgoing_kind, current_transaction_status))]
fn select_events_current(
&self, outgoing_kind: OutgoingKind, current_transaction_status: &mut HashMap<OutgoingKind, TransactionStatus>,
) -> Result<(bool, bool)> {
let (mut allow, mut retry) = (true, false);
current_transaction_status
.entry(outgoing_kind)
.and_modify(|e| match e {
TransactionStatus::Failed(tries, time) => {
// Fail if a request has failed recently (exponential backoff)
const MAX_DURATION: Duration = Duration::from_secs(60 * 60 * 24);
let mut min_elapsed_duration = Duration::from_secs(self.timeout) * (*tries) * (*tries);
min_elapsed_duration = std::cmp::min(min_elapsed_duration, MAX_DURATION);
if time.elapsed() < min_elapsed_duration {
allow = false;
} else {
retry = true;
*e = TransactionStatus::Retrying(*tries);
}
},
TransactionStatus::Running | TransactionStatus::Retrying(_) => {
allow = false; // already running
},
})
.or_insert(TransactionStatus::Running);
Ok((allow, retry))
}
#[tracing::instrument(skip(self, server_name))] #[tracing::instrument(skip(self, server_name))]
pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> { pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec<Vec<u8>>, u64)> {
// u64: count of last edu // u64: count of last edu
@ -418,7 +433,7 @@ impl Service {
let mut max_edu_count = since; let mut max_edu_count = since;
let mut device_list_changes = HashSet::new(); let mut device_list_changes = HashSet::new();
'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { for room_id in services().rooms.state_cache.server_rooms(server_name) {
let room_id = room_id?; let room_id = room_id?;
// Look for device list updates in this room // Look for device list updates in this room
device_list_changes.extend( device_list_changes.extend(
@ -428,21 +443,54 @@ impl Service {
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|user_id| user_id.server_name() == services().globals.server_name()), .filter(|user_id| user_id.server_name() == services().globals.server_name()),
); );
if !select_edus_presence(&room_id, since, &mut max_edu_count, &mut events)? {
break;
}
if !select_edus_receipts(&room_id, since, &mut max_edu_count, &mut events)? {
break;
}
}
for user_id in device_list_changes {
// Empty prev id forces synapse to resync; because synapse resyncs,
// we can just insert placeholder data
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
user_id,
device_id: device_id!("placeholder").to_owned(),
device_display_name: Some("Placeholder".to_owned()),
stream_id: uint!(1),
prev_id: Vec::new(),
deleted: None,
keys: None,
});
events.push(serde_json::to_vec(&edu).expect("json can be serialized"));
}
Ok((events, max_edu_count))
}
}
/// Look for presence [in this room] <--- XXX
#[tracing::instrument(skip(room_id, since, max_edu_count, events))]
pub fn select_edus_presence(
room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
if !services().globals.allow_outgoing_presence() {
return Ok(true);
}
if services().globals.allow_outgoing_presence() {
// Look for presence updates in this room // Look for presence updates in this room
let mut presence_updates = Vec::new(); let mut presence_updates = Vec::new();
for (user_id, count, presence_event) in services() for (user_id, count, presence_event) in services()
.rooms .rooms
.edus .edus
.presence .presence
.presence_since(&room_id, since) .presence_since(room_id, since)
{ {
if count > max_edu_count { if count > *max_edu_count {
max_edu_count = count; *max_edu_count = count;
} }
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
continue; continue;
} }
@ -461,21 +509,26 @@ impl Service {
let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); let presence_content = Edu::Presence(PresenceContent::new(presence_updates));
events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"));
}
// Look for read receipts in this room Ok(true)
}
/// Look for read receipts in this room
#[tracing::instrument(skip(room_id, since, max_edu_count, events))]
pub fn select_edus_receipts(
room_id: &RoomId, since: u64, max_edu_count: &mut u64, events: &mut Vec<Vec<u8>>,
) -> Result<bool> {
for r in services() for r in services()
.rooms .rooms
.edus .edus
.read_receipt .read_receipt
.readreceipts_since(&room_id, since) .readreceipts_since(room_id, since)
{ {
let (user_id, count, read_receipt) = r?; let (user_id, count, read_receipt) = r?;
if count > max_edu_count { if count > *max_edu_count {
max_edu_count = count; *max_edu_count = count;
} }
if user_id.server_name() != services().globals.server_name() { if user_id.server_name() != services().globals.server_name() {
continue; continue;
} }
@ -510,7 +563,7 @@ impl Service {
}; };
let mut receipts = BTreeMap::new(); let mut receipts = BTreeMap::new();
receipts.insert(room_id.clone(), receipt_map); receipts.insert(room_id.to_owned(), receipt_map);
Edu::Receipt(ReceiptContent { Edu::Receipt(ReceiptContent {
receipts, receipts,
@ -523,36 +576,28 @@ impl Service {
events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); events.push(serde_json::to_vec(&federation_event).expect("json can be serialized"));
if events.len() >= 20 { if events.len() >= 20 {
break 'outer; return Ok(false);
}
} }
} }
for user_id in device_list_changes { Ok(true)
// Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 }
// Because synapse resyncs, we can just insert placeholder data
let edu = Edu::DeviceListUpdate(DeviceListUpdateContent {
user_id,
device_id: device_id!("placeholder").to_owned(),
device_display_name: Some("Placeholder".to_owned()),
stream_id: uint!(1),
prev_id: Vec::new(),
deleted: None,
keys: None,
});
events.push(serde_json::to_vec(&edu).expect("json can be serialized")); #[tracing::instrument(skip(events, kind))]
} async fn handle_events(
Ok((events, max_edu_count))
}
#[tracing::instrument(skip(events, kind))]
async fn handle_events(
kind: OutgoingKind, events: Vec<SendingEventType>, kind: OutgoingKind, events: Vec<SendingEventType>,
) -> Result<OutgoingKind, (OutgoingKind, Error)> { ) -> Result<OutgoingKind, (OutgoingKind, Error)> {
match &kind { match kind {
OutgoingKind::Appservice(id) => { OutgoingKind::Appservice(ref id) => handle_events_kind_appservice(&kind, id, events).await,
OutgoingKind::Push(ref userid, ref pushkey) => handle_events_kind_push(&kind, userid, pushkey, events).await,
OutgoingKind::Normal(ref server) => handle_events_kind_normal(&kind, server, events).await,
}
}
#[tracing::instrument(skip(kind, events))]
async fn handle_events_kind_appservice(
kind: &OutgoingKind, id: &String, events: Vec<SendingEventType>,
) -> Result<OutgoingKind, (OutgoingKind, Error)> {
let mut pdu_jsons = Vec::new(); let mut pdu_jsons = Vec::new();
for event in &events { for event in &events {
@ -567,9 +612,7 @@ impl Service {
.ok_or_else(|| { .ok_or_else(|| {
( (
kind.clone(), kind.clone(),
Error::bad_database( Error::bad_database("[Appservice] Event in servernameevent_data not found in db."),
"[Appservice] Event in servernameevent_data not found in db.",
),
) )
})? })?
.to_room_event(), .to_room_event(),
@ -620,8 +663,12 @@ impl Service {
drop(permit); drop(permit);
response response
}, }
OutgoingKind::Push(userid, pushkey) => {
#[tracing::instrument(skip(kind, events))]
async fn handle_events_kind_push(
kind: &OutgoingKind, userid: &OwnedUserId, pushkey: &String, events: Vec<SendingEventType>,
) -> Result<OutgoingKind, (OutgoingKind, Error)> {
let mut pdus = Vec::new(); let mut pdus = Vec::new();
for event in &events { for event in &events {
@ -636,9 +683,7 @@ impl Service {
.ok_or_else(|| { .ok_or_else(|| {
( (
kind.clone(), kind.clone(),
Error::bad_database( Error::bad_database("[Push] Event in servernamevent_datas not found in db."),
"[Push] Event in servernamevent_datas not found in db.",
),
) )
})?, })?,
); );
@ -663,7 +708,7 @@ impl Service {
let Some(pusher) = services() let Some(pusher) = services()
.pusher .pusher
.get_pusher(userid, pushkey) .get_pusher(userid, pushkey)
.map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? .map_err(|e| (kind.clone(), e))?
else { else {
continue; continue;
}; };
@ -694,9 +739,14 @@ impl Service {
drop(permit); drop(permit);
} }
Ok(OutgoingKind::Push(userid.clone(), pushkey.clone()))
}, Ok(kind.clone())
OutgoingKind::Normal(server) => { }
#[tracing::instrument(skip(kind, events))]
async fn handle_events_kind_normal(
kind: &OutgoingKind, server: &OwnedServerName, events: Vec<SendingEventType>,
) -> Result<OutgoingKind, (OutgoingKind, Error)> {
let mut edu_jsons = Vec::new(); let mut edu_jsons = Vec::new();
let mut pdu_jsons = Vec::new(); let mut pdu_jsons = Vec::new();
@ -709,14 +759,12 @@ impl Service {
.rooms .rooms
.timeline .timeline
.get_pdu_json_from_id(pdu_id) .get_pdu_json_from_id(pdu_id)
.map_err(|e| (OutgoingKind::Normal(server.clone()), e))? .map_err(|e| (kind.clone(), e))?
.ok_or_else(|| { .ok_or_else(|| {
error!("event not found: {server} {pdu_id:?}"); error!("event not found: {server} {pdu_id:?}");
( (
OutgoingKind::Normal(server.clone()), kind.clone(),
Error::bad_database( Error::bad_database("[Normal] Event in servernamevent_datas not found in db."),
"[Normal] Event in servernamevent_datas not found in db.",
),
) )
})?, })?,
); );
@ -763,80 +811,37 @@ impl Service {
} }
kind.clone() kind.clone()
}) })
.map_err(|e| (kind, e)); .map_err(|e| (kind.clone(), e));
drop(permit); drop(permit);
response response
}
impl OutgoingKind {
#[tracing::instrument(skip(self))]
pub fn get_prefix(&self) -> Vec<u8> {
let mut prefix = match self {
OutgoingKind::Appservice(server) => {
let mut p = b"+".to_vec();
p.extend_from_slice(server.as_bytes());
p
}, },
} OutgoingKind::Push(user, pushkey) => {
} let mut p = b"$".to_vec();
p.extend_from_slice(user.as_bytes());
p.push(0xFF);
p.extend_from_slice(pushkey.as_bytes());
p
},
OutgoingKind::Normal(server) => {
let mut p = Vec::new();
p.extend_from_slice(server.as_bytes());
p
},
};
prefix.push(0xFF);
#[tracing::instrument(skip(self, destination, request))] prefix
pub async fn send_federation_request<T>(&self, destination: &ServerName, request: T) -> Result<T::IncomingResponse>
where
T: OutgoingRequest + Debug,
{
if !services().globals.allow_federation() {
return Err(Error::bad_config("Federation is disabled."));
}
if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) {
info!(
"Destination {} is an IP literal, checking against IP range denylist.",
destination
);
let ip = IPAddress::parse(destination.host()).map_err(|e| {
warn!("Failed to parse IP literal from string: {}", e);
Error::BadServerResponse("Invalid IP address")
})?;
let cidr_ranges_s = services().globals.ip_range_denylist().to_vec();
let mut cidr_ranges: Vec<IPAddress> = Vec::new();
for cidr in cidr_ranges_s {
cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup"));
}
debug!("List of pushed CIDR ranges: {:?}", cidr_ranges);
for cidr in cidr_ranges {
if cidr.includes(&ip) {
return Err(Error::BadServerResponse("Not allowed to send requests to this IP"));
}
}
info!("IP literal {} is allowed.", destination);
}
debug!("Waiting for permit");
let permit = self.maximum_requests.acquire().await;
debug!("Got permit");
let response = tokio::time::timeout(Duration::from_secs(5 * 60), send::send_request(destination, request))
.await
.map_err(|_| {
warn!("Timeout after 300 seconds waiting for server response of {destination}");
Error::BadServerResponse("Timeout after 300 seconds waiting for server response")
})?;
drop(permit);
response
}
/// Sends a request to an appservice
///
/// Only returns None if there is no url specified in the appservice
/// registration file
pub async fn send_appservice_request<T>(
&self, registration: Registration, request: T,
) -> Option<Result<T::IncomingResponse>>
where
T: OutgoingRequest + Debug,
{
let permit = self.maximum_requests.acquire().await;
let response = appservice::send_request(registration, request).await;
drop(permit);
response
} }
} }