split send txn handler

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-03 00:06:00 +00:00
parent 473b29d524
commit df0328f43f

View file

@ -1,4 +1,4 @@
use std::{collections::BTreeMap, time::Instant}; use std::{collections::BTreeMap, net::IpAddr, time::Instant};
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::debug_warn; use conduit::debug_warn;
@ -6,12 +6,16 @@ use ruma::{
api::{ api::{
client::error::ErrorKind, client::error::ErrorKind,
federation::transactions::{ federation::transactions::{
edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, edu::{
DeviceListUpdateContent, DirectDeviceContent, Edu, PresenceContent, ReceiptContent,
SigningKeyUpdateContent, TypingContent,
},
send_transaction_message, send_transaction_message,
}, },
}, },
events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, events::receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType},
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
OwnedEventId, ServerName,
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::{debug, error, trace, warn}; use tracing::{debug, error, trace, warn};
@ -23,6 +27,8 @@ use crate::{
Error, Result, Ruma, Error, Result, Ruma,
}; };
type ResolvedMap = BTreeMap<OwnedEventId, Result<(), Error>>;
/// # `PUT /_matrix/federation/v1/send/{txnId}` /// # `PUT /_matrix/federation/v1/send/{txnId}`
/// ///
/// Push EDUs and PDUs to this server. /// Push EDUs and PDUs to this server.
@ -53,15 +59,39 @@ pub(crate) async fn send_transaction_message_route(
)); ));
} }
// This is all the auth_events that have been recursively fetched so they don't
// have to be deserialized over and over again.
// TODO: make this persist across requests but not in a DB Tree (in globals?)
// TODO: This could potentially also be some sort of trie (suffix tree) like
// structure so that once an auth event is known it would know (using indexes
// maybe) all of the auth events that it references.
// let mut auth_cache = EventMap::new();
let txn_start_time = Instant::now(); let txn_start_time = Instant::now();
trace!(
pdus = ?body.pdus.len(),
edus = ?body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin,
"Starting txn",
);
let resolved_map = handle_pdus(&client, &body, origin, &txn_start_time).await?;
handle_edus(&client, &body, origin).await?;
debug!(
pdus = ?body.pdus.len(),
edus = ?body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin,
"Finished txn",
);
Ok(send_transaction_message::v1::Response {
pdus: resolved_map
.into_iter()
.map(|(e, r)| (e, r.map_err(|e| e.sanitized_error())))
.collect(),
})
}
async fn handle_pdus(
_client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName, txn_start_time: &Instant,
) -> Result<ResolvedMap> {
let mut parsed_pdus = Vec::with_capacity(body.pdus.len()); let mut parsed_pdus = Vec::with_capacity(body.pdus.len());
for pdu in &body.pdus { for pdu in &body.pdus {
parsed_pdus.push(match parse_incoming_pdu(pdu) { parsed_pdus.push(match parse_incoming_pdu(pdu) {
@ -76,15 +106,6 @@ pub(crate) async fn send_transaction_message_route(
// and hashes checks // and hashes checks
} }
trace!(
pdus = ?parsed_pdus.len(),
edus = ?body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin,
"Starting txn",
);
// We go through all the signatures we see on the PDUs and fetch the // We go through all the signatures we see on the PDUs and fetch the
// corresponding signing keys // corresponding signing keys
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
@ -94,9 +115,7 @@ pub(crate) async fn send_transaction_message_route(
.event_handler .event_handler
.fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) .fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map)
.await .await
.unwrap_or_else(|e| { .unwrap_or_else(|e| warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}"));
warn!("Could not fetch all signatures for PDUs from {origin}: {:?}", e);
});
debug!( debug!(
elapsed = ?txn_start_time.elapsed(), elapsed = ?txn_start_time.elapsed(),
@ -133,25 +152,49 @@ pub(crate) async fn send_transaction_message_route(
for pdu in &resolved_map { for pdu in &resolved_map {
if let Err(e) = pdu.1 { if let Err(e) = pdu.1 {
if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) {
warn!("Incoming PDU failed {:?}", pdu); warn!("Incoming PDU failed {pdu:?}");
} }
} }
} }
Ok(resolved_map)
}
async fn handle_edus(
client: &IpAddr, body: &Ruma<send_transaction_message::v1::Request>, origin: &ServerName,
) -> Result<()> {
for edu in body for edu in body
.edus .edus
.iter() .iter()
.filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok()) .filter_map(|edu| serde_json::from_str::<Edu>(edu.json().get()).ok())
{ {
match edu { match edu {
Edu::Presence(presence) => { Edu::Presence(presence) => handle_edu_presence(client, origin, presence).await?,
Edu::Receipt(receipt) => handle_edu_receipt(client, origin, receipt).await?,
Edu::Typing(typing) => handle_edu_typing(client, origin, typing).await?,
Edu::DeviceListUpdate(content) => handle_edu_device_list_update(client, origin, content).await?,
Edu::DirectToDevice(content) => handle_edu_direct_to_device(client, origin, content).await?,
Edu::SigningKeyUpdate(content) => handle_edu_signing_key_update(client, origin, content).await?,
Edu::_Custom(ref _custom) => {
debug_warn!(?body.edus, "received custom/unknown EDU");
},
}
}
Ok(())
}
async fn handle_edu_presence(_client: &IpAddr, origin: &ServerName, presence: PresenceContent) -> Result<()> {
if !services().globals.allow_incoming_presence() { if !services().globals.allow_incoming_presence() {
continue; return Ok(());
} }
for update in presence.push { for update in presence.push {
if update.user_id.server_name() != origin { if update.user_id.server_name() != origin {
debug_warn!(%update.user_id, %origin, "received presence EDU for user not belonging to origin"); debug_warn!(
%update.user_id, %origin,
"received presence EDU for user not belonging to origin"
);
continue; continue;
} }
@ -163,10 +206,13 @@ pub(crate) async fn send_transaction_message_route(
update.status_msg.clone(), update.status_msg.clone(),
)?; )?;
} }
},
Edu::Receipt(receipt) => { Ok(())
}
async fn handle_edu_receipt(_client: &IpAddr, origin: &ServerName, receipt: ReceiptContent) -> Result<()> {
if !services().globals.allow_incoming_read_receipts() { if !services().globals.allow_incoming_read_receipts() {
continue; return Ok(());
} }
for (room_id, room_updates) in receipt.receipts { for (room_id, room_updates) in receipt.receipts {
@ -176,13 +222,19 @@ pub(crate) async fn send_transaction_message_route(
.acl_check(origin, &room_id) .acl_check(origin, &room_id)
.is_err() .is_err()
{ {
debug_warn!(%origin, %room_id, "received read receipt EDU from ACL'd server"); debug_warn!(
%origin, %room_id,
"received read receipt EDU from ACL'd server"
);
continue; continue;
} }
for (user_id, user_updates) in room_updates.read { for (user_id, user_updates) in room_updates.read {
if user_id.server_name() != origin { if user_id.server_name() != origin {
debug_warn!(%user_id, %origin, "received read receipt EDU for user not belonging to origin"); debug_warn!(
%user_id, %origin,
"received read receipt EDU for user not belonging to origin"
);
continue; continue;
} }
@ -195,11 +247,8 @@ pub(crate) async fn send_transaction_message_route(
{ {
for event_id in &user_updates.event_ids { for event_id in &user_updates.event_ids {
let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]); let user_receipts = BTreeMap::from([(user_id.clone(), user_updates.data.clone())]);
let receipts = BTreeMap::from([(ReceiptType::Read, user_receipts)]); let receipts = BTreeMap::from([(ReceiptType::Read, user_receipts)]);
let receipt_content = BTreeMap::from([(event_id.to_owned(), receipts)]); let receipt_content = BTreeMap::from([(event_id.to_owned(), receipts)]);
let event = ReceiptEvent { let event = ReceiptEvent {
content: ReceiptEventContent(receipt_content), content: ReceiptEventContent(receipt_content),
room_id: room_id.clone(), room_id: room_id.clone(),
@ -211,20 +260,29 @@ pub(crate) async fn send_transaction_message_route(
.readreceipt_update(&user_id, &room_id, &event)?; .readreceipt_update(&user_id, &room_id, &event)?;
} }
} else { } else {
debug_warn!(%user_id, %room_id, %origin, "received read receipt EDU from server who does not have a single member from their server in the room"); debug_warn!(
%user_id, %room_id, %origin,
"received read receipt EDU from server who does not have a member in the room",
);
continue; continue;
} }
} }
} }
},
Edu::Typing(typing) => { Ok(())
}
async fn handle_edu_typing(_client: &IpAddr, origin: &ServerName, typing: TypingContent) -> Result<()> {
if !services().globals.config.allow_incoming_typing { if !services().globals.config.allow_incoming_typing {
continue; return Ok(());
} }
if typing.user_id.server_name() != origin { if typing.user_id.server_name() != origin {
debug_warn!(%typing.user_id, %origin, "received typing EDU for user not belonging to origin"); debug_warn!(
continue; %typing.user_id, %origin,
"received typing EDU for user not belonging to origin"
);
return Ok(());
} }
if services() if services()
@ -233,8 +291,11 @@ pub(crate) async fn send_transaction_message_route(
.acl_check(typing.user_id.server_name(), &typing.room_id) .acl_check(typing.user_id.server_name(), &typing.room_id)
.is_err() .is_err()
{ {
debug_warn!(%typing.user_id, %typing.room_id, %origin, "received typing EDU for ACL'd user's server"); debug_warn!(
continue; %typing.user_id, %typing.room_id, %origin,
"received typing EDU for ACL'd user's server"
);
return Ok(());
} }
if services() if services()
@ -263,30 +324,53 @@ pub(crate) async fn send_transaction_message_route(
.await?; .await?;
} }
} else { } else {
debug_warn!(%typing.user_id, %typing.room_id, %origin, "received typing EDU for user not in room"); debug_warn!(
continue; %typing.user_id, %typing.room_id, %origin,
"received typing EDU for user not in room"
);
return Ok(());
} }
},
Edu::DeviceListUpdate(DeviceListUpdateContent { Ok(())
}
async fn handle_edu_device_list_update(
_client: &IpAddr, origin: &ServerName, content: DeviceListUpdateContent,
) -> Result<()> {
let DeviceListUpdateContent {
user_id, user_id,
.. ..
}) => { } = content;
if user_id.server_name() != origin { if user_id.server_name() != origin {
debug_warn!(%user_id, %origin, "received device list update EDU for user not belonging to origin"); debug_warn!(
continue; %user_id, %origin,
"received device list update EDU for user not belonging to origin"
);
return Ok(());
} }
services().users.mark_device_key_update(&user_id)?; services().users.mark_device_key_update(&user_id)?;
},
Edu::DirectToDevice(DirectDeviceContent { Ok(())
}
async fn handle_edu_direct_to_device(
_client: &IpAddr, origin: &ServerName, content: DirectDeviceContent,
) -> Result<()> {
let DirectDeviceContent {
sender, sender,
ev_type, ev_type,
message_id, message_id,
messages, messages,
}) => { } = content;
if sender.server_name() != origin { if sender.server_name() != origin {
debug_warn!(%sender, %origin, "received direct to device EDU for user not belonging to origin"); debug_warn!(
continue; %sender, %origin,
"received direct to device EDU for user not belonging to origin"
);
return Ok(());
} }
// Check if this is a new transaction id // Check if this is a new transaction id
@ -295,7 +379,7 @@ pub(crate) async fn send_transaction_message_route(
.existing_txnid(&sender, None, &message_id)? .existing_txnid(&sender, None, &message_id)?
.is_some() .is_some()
{ {
continue; return Ok(());
} }
for (target_user_id, map) in &messages { for (target_user_id, map) in &messages {
@ -321,9 +405,9 @@ pub(crate) async fn send_transaction_message_route(
target_user_id, target_user_id,
&target_device_id?, &target_device_id?,
&ev_type.to_string(), &ev_type.to_string(),
event.deserialize_as().map_err(|_| { event
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") .deserialize_as()
})?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
)?; )?;
} }
}, },
@ -335,15 +419,25 @@ pub(crate) async fn send_transaction_message_route(
services() services()
.transaction_ids .transaction_ids
.add_txnid(&sender, None, &message_id, &[])?; .add_txnid(&sender, None, &message_id, &[])?;
},
Edu::SigningKeyUpdate(SigningKeyUpdateContent { Ok(())
}
async fn handle_edu_signing_key_update(
_client: &IpAddr, origin: &ServerName, content: SigningKeyUpdateContent,
) -> Result<()> {
let SigningKeyUpdateContent {
user_id, user_id,
master_key, master_key,
self_signing_key, self_signing_key,
}) => { } = content;
if user_id.server_name() != origin { if user_id.server_name() != origin {
debug_warn!(%user_id, %origin, "received signing key update EDU from server that does not belong to user's server"); debug_warn!(
continue; %user_id, %origin,
"received signing key update EDU from server that does not belong to user's server"
);
return Ok(());
} }
if let Some(master_key) = master_key { if let Some(master_key) = master_key {
@ -351,26 +445,6 @@ pub(crate) async fn send_transaction_message_route(
.users .users
.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; .add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?;
} }
},
Edu::_Custom(ref _custom) => {
debug_warn!(?edu, "received custom/unknown EDU");
},
}
}
debug!( Ok(())
pdus = ?body.pdus.len(),
edus = ?body.edus.len(),
elapsed = ?txn_start_time.elapsed(),
id = ?body.transaction_id,
origin =?body.origin,
"Finished txn",
);
Ok(send_transaction_message::v1::Response {
pdus: resolved_map
.into_iter()
.map(|(e, r)| (e, r.map_err(|e| e.sanitized_error())))
.collect(),
})
} }