Refactor server_keys service/interface and related callsites

Signed-off-by: Jason Volk <jason@zemos.net>
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Jason Volk 2024-10-11 18:57:59 +00:00 committed by strawberry
parent d82ea331cf
commit c0939c3e9a
30 changed files with 1025 additions and 1378 deletions

View file

@ -28,12 +28,10 @@ use ruma::{
StateEventType, TimelineEventType,
},
int,
serde::Base64,
state_res::{self, EventTypeExt, RoomVersion, StateMap},
uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, RoomId, RoomVersionId,
ServerName, UserId,
uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
RoomId, RoomVersionId, ServerName, UserId,
};
use tokio::sync::RwLock;
use super::state_compressor::CompressedStateEvent;
use crate::{globals, rooms, sending, server_keys, Dep};
@ -129,11 +127,10 @@ impl Service {
/// 13. Use state resolution to find new room state
/// 14. Check if the event passes auth based on the "current state" of the
/// room, if not soft fail it
#[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")]
#[tracing::instrument(skip(self, origin, value, is_timeline_event), name = "pdu")]
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
@ -177,7 +174,7 @@ impl Service {
let first_pdu_in_room = self.services.timeline.first_pdu_in_room(room_id).await?;
let (incoming_pdu, val) = self
.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map)
.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false)
.boxed()
.await?;
@ -200,7 +197,6 @@ impl Service {
&create_event,
room_id,
&room_version_id,
pub_key_map,
incoming_pdu.prev_events.clone(),
)
.await?;
@ -212,7 +208,6 @@ impl Service {
origin,
event_id,
room_id,
pub_key_map,
&mut eventid_info,
&create_event,
&first_pdu_in_room,
@ -250,7 +245,7 @@ impl Service {
.insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = self
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map)
.upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id)
.await;
self.federation_handletime
@ -264,12 +259,11 @@ impl Service {
#[allow(clippy::type_complexity)]
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(
skip(self, origin, event_id, room_id, pub_key_map, eventid_info, create_event, first_pdu_in_room),
skip(self, origin, event_id, room_id, eventid_info, create_event, first_pdu_in_room),
name = "prev"
)]
pub async fn handle_prev_pdu<'a>(
&self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
eventid_info: &mut HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
create_event: &Arc<PduEvent>, first_pdu_in_room: &Arc<PduEvent>, prev_id: &EventId,
) -> Result<()> {
@ -318,7 +312,7 @@ impl Service {
.expect("locked")
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id, pub_key_map)
self.upgrade_outlier_to_timeline_pdu(pdu, json, create_event, origin, room_id)
.await?;
self.federation_handletime
@ -338,8 +332,7 @@ impl Service {
#[allow(clippy::too_many_arguments)]
async fn handle_outlier_pdu<'a>(
&self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId,
mut value: BTreeMap<String, CanonicalJsonValue>, auth_events_known: bool,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
mut value: CanonicalJsonObject, auth_events_known: bool,
) -> Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)> {
// 1. Remove unsigned field
value.remove("unsigned");
@ -349,14 +342,13 @@ impl Service {
// 2. Check signatures, otherwise drop
// 3. check content hash, redact if doesn't match
let room_version_id = Self::get_room_version_id(create_event)?;
let guard = pub_key_map.read().await;
let mut val = match ruma::signatures::verify_event(&guard, &value, &room_version_id) {
Err(e) => {
// Drop
warn!("Dropping bad event {event_id}: {e}");
return Err!(Request(InvalidParam("Signature verification failed")));
},
let mut val = match self
.services
.server_keys
.verify_event(&value, Some(&room_version_id))
.await
{
Ok(ruma::signatures::Verified::All) => value,
Ok(ruma::signatures::Verified::Signatures) => {
// Redact
debug_info!("Calculated hash does not match (redaction): {event_id}");
@ -371,11 +363,13 @@ impl Service {
obj
},
Ok(ruma::signatures::Verified::All) => value,
Err(e) => {
return Err!(Request(InvalidParam(debug_error!(
"Signature verification failed for {event_id}: {e}"
))))
},
};
drop(guard);
// Now that we have checked the signature and hashes we can add the eventID and
// convert to our PduEvent type
val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
@ -404,7 +398,6 @@ impl Service {
create_event,
room_id,
&room_version_id,
pub_key_map,
),
)
.await;
@ -487,7 +480,7 @@ impl Service {
pub async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
origin: &ServerName, room_id: &RoomId,
) -> Result<Option<Vec<u8>>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
@ -526,14 +519,7 @@ impl Service {
if state_at_incoming_event.is_none() {
state_at_incoming_event = self
.fetch_state(
origin,
create_event,
room_id,
&room_version_id,
pub_key_map,
&incoming_pdu.event_id,
)
.fetch_state(origin, create_event, room_id, &room_version_id, &incoming_pdu.event_id)
.await?;
}
@ -1021,10 +1007,10 @@ impl Service {
/// Call /state_ids to find out what the state at this pdu is. We trust the
/// server's response to some extend (sic), but we still do a lot of checks
/// on the events
#[tracing::instrument(skip(self, pub_key_map, create_event, room_version_id))]
#[tracing::instrument(skip(self, create_event, room_version_id))]
async fn fetch_state(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, event_id: &EventId,
event_id: &EventId,
) -> Result<Option<HashMap<u64, Arc<EventId>>>> {
debug!("Fetching state ids");
let res = self
@ -1048,7 +1034,7 @@ impl Service {
.collect::<Vec<_>>();
let state_vec = self
.fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id, pub_key_map)
.fetch_and_handle_outliers(origin, &collect, create_event, room_id, room_version_id)
.boxed()
.await;
@ -1102,7 +1088,7 @@ impl Service {
/// d. TODO: Ask other servers over federation?
pub async fn fetch_and_handle_outliers<'a>(
&self, origin: &'a ServerName, events: &'a [Arc<EventId>], create_event: &'a PduEvent, room_id: &'a RoomId,
room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
room_version_id: &'a RoomVersionId,
) -> Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)> {
let back_off = |id| match self
.services
@ -1222,22 +1208,6 @@ impl Service {
events_with_auth_events.push((id, None, events_in_reverse_order));
}
// We go through all the signatures we see on the PDUs and their unresolved
// dependencies and fetch the corresponding signing keys
self.services
.server_keys
.fetch_required_signing_keys(
events_with_auth_events
.iter()
.flat_map(|(_id, _local_pdu, events)| events)
.map(|(_event_id, event)| event),
pub_key_map,
)
.await
.unwrap_or_else(|e| {
warn!("Could not fetch all signatures for PDUs from {origin}: {e:?}");
});
let mut pdus = Vec::with_capacity(events_with_auth_events.len());
for (id, local_pdu, events_in_reverse_order) in events_with_auth_events {
// a. Look in the main timeline (pduid_pdu tree)
@ -1266,16 +1236,8 @@ impl Service {
}
}
match Box::pin(self.handle_outlier_pdu(
origin,
create_event,
&next_id,
room_id,
value.clone(),
true,
pub_key_map,
))
.await
match Box::pin(self.handle_outlier_pdu(origin, create_event, &next_id, room_id, value.clone(), true))
.await
{
Ok((pdu, json)) => {
if next_id == *id {
@ -1296,7 +1258,7 @@ impl Service {
#[tracing::instrument(skip_all)]
async fn fetch_prev(
&self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, initial_set: Vec<Arc<EventId>>,
initial_set: Vec<Arc<EventId>>,
) -> Result<(
Vec<Arc<EventId>>,
HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>,
@ -1311,14 +1273,7 @@ impl Service {
while let Some(prev_event_id) = todo_outlier_stack.pop() {
if let Some((pdu, mut json_opt)) = self
.fetch_and_handle_outliers(
origin,
&[prev_event_id.clone()],
create_event,
room_id,
room_version_id,
pub_key_map,
)
.fetch_and_handle_outliers(origin, &[prev_event_id.clone()], create_event, room_id, room_version_id)
.boxed()
.await
.pop()