diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 43fcf8f7..da119db5 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,25 +1,14 @@ -/// An async function that can recursively call itself. -type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; - use std::{ collections::{hash_map, HashSet}, pin::Pin, - time::{Duration, Instant, SystemTime}, + time::{Duration, Instant}, }; -use futures_util::{stream::FuturesUnordered, Future, StreamExt}; +use futures_util::Future; use ruma::{ api::{ client::error::ErrorKind, - federation::{ - discovery::{ - get_remote_server_keys, - get_remote_server_keys_batch::{self, v2::QueryCriteria}, - get_server_keys, - }, - event::{get_event, get_room_state_ids}, - membership::create_join_event, - }, + federation::event::{get_event, get_room_state_ids}, }, events::{ room::{create::RoomCreateEventContent, server_acl::RoomServerAclEventContent}, @@ -28,11 +17,9 @@ use ruma::{ int, serde::Base64, state_res::{self, RoomVersion, StateMap}, - uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedServerName, - OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName, + uint, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, RoomId, RoomVersionId, ServerName, }; -use serde_json::value::RawValue as RawJsonValue; -use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore}; +use tokio::sync::RwLock; use tracing::{debug, error, info, trace, warn}; use super::state_compressor::CompressedStateEvent; @@ -41,13 +28,17 @@ use crate::{ services, Error, PduEvent, }; +pub mod signing_keys; +pub struct Service; + +// We use some AsyncRecursiveType hacks here so we can call async funtion +// recursively. +type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; type AsyncRecursiveCanonicalJsonVec<'a> = AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; type AsyncRecursiveCanonicalJsonResult<'a> = AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; -pub struct Service; - impl Service { /// When receiving an event one needs to: /// 0. Check the server is in the room @@ -76,8 +67,6 @@ impl Service { /// 13. Use state resolution to find new room state /// 14. Check if the event passes auth based on the "current state" of the /// room, if not soft fail it - // We use some AsyncRecursiveType hacks here so we can call this async funtion - // recursively #[tracing::instrument(skip(self, origin, value, is_timeline_event, pub_key_map), name = "pdu")] pub(crate) async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, @@ -1269,348 +1258,6 @@ impl Service { Ok((sorted, eventid_info)) } - pub(crate) async fn fetch_required_signing_keys<'a, E>( - &'a self, events: E, pub_key_map: &RwLock>>, - ) -> Result<()> - where - E: IntoIterator>, - { - let mut server_key_ids = HashMap::new(); - - for event in events { - debug!("Fetching keys for event: {event:?}"); - for (signature_server, signature) in event - .get("signatures") - .ok_or(Error::BadServerResponse("No signatures in server response pdu."))? - .as_object() - .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))? - { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - for signature_id in signature_object.keys() { - server_key_ids - .entry(signature_server.clone()) - .or_insert_with(HashSet::new) - .insert(signature_id.clone()); - } - } - } - - if server_key_ids.is_empty() { - // Nothing to do, can exit early - trace!("server_key_ids is empty, not fetching any keys"); - return Ok(()); - } - - debug!( - "Fetch keys for {}", - server_key_ids - .keys() - .cloned() - .collect::>() - .join(", ") - ); - - let mut server_keys: FuturesUnordered<_> = server_key_ids - .into_iter() - .map(|(signature_server, signature_ids)| async { - let fetch_res = self - .fetch_signing_keys_for_server( - signature_server.as_str().try_into().map_err(|e| { - info!("Invalid servername in signatures of server response pdu: {e}"); - ( - signature_server.clone(), - Error::BadServerResponse("Invalid servername in signatures of server response pdu."), - ) - })?, - signature_ids.into_iter().collect(), // HashSet to Vec - ) - .await; - - match fetch_res { - Ok(keys) => Ok((signature_server, keys)), - Err(e) => { - warn!("Signature verification failed: Could not fetch signing key for {signature_server}: {e}",); - Err((signature_server, e)) - }, - } - }) - .collect(); - - while let Some(fetch_res) = server_keys.next().await { - match fetch_res { - Ok((signature_server, keys)) => { - pub_key_map - .write() - .await - .insert(signature_server.clone(), keys); - }, - Err((signature_server, e)) => { - warn!("Failed to fetch keys for {}: {:?}", signature_server, e); - }, - } - } - - Ok(()) - } - - // Gets a list of servers for which we don't have the signing key yet. We go - // over the PDUs and either cache the key or add it to the list that needs to be - // retrieved. - async fn get_server_keys_from_cache( - &self, pdu: &RawJsonValue, - servers: &mut BTreeMap>, - room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, - ) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .await - .get(event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - let signatures = value - .get("signatures") - .ok_or(Error::BadServerResponse("No signatures in server response pdu."))? - .as_object() - .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::>(); - - let contains_all_ids = - |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { - info!("Invalid servername in signatures of server response pdu: {e}"); - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - debug!("Loading signing keys for {}", origin); - - let result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - debug!("Signing key not loaded for {}", origin); - servers.insert(origin.to_owned(), BTreeMap::new()); - } - - pub_key_map.insert(origin.to_string(), result); - } - - Ok(()) - } - - /// Batch requests homeserver signing keys from trusted notary key servers - /// (`trusted_servers` config option) - async fn batch_request_signing_keys( - &self, mut servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - for server in services().globals.trusted_servers() { - info!("Asking batch signing keys from trusted server {}", server); - match services() - .sending - .send_federation_request( - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - Ok(keys) => { - debug!("Got signing keys: {:?}", keys); - let mut pkm = pub_key_map.write().await; - for k in keys.server_keys { - let k = match k.deserialize() { - Ok(key) => key, - Err(e) => { - warn!("Received error {e} while fetching keys from trusted server {server}"); - warn!("{}", k.into_json()); - continue; - }, - }; - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = services() - .globals - .add_signing_key(&k.server_name, k.clone())? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::>(); - - pkm.insert(k.server_name.to_string(), result); - } - }, - Err(e) => { - warn!( - "Failed sending batched key request to trusted key server {server} for the remote servers \ - {:?}: {e}", - servers - ); - }, - } - } - - Ok(()) - } - - /// Requests multiple homeserver signing keys from individual servers (not - /// trused notary servers) - async fn request_signing_keys( - &self, servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - info!("Asking individual servers for signing keys: {servers:?}"); - let mut futures: FuturesUnordered<_> = servers - .into_keys() - .map(|server| async move { - ( - services() - .sending - .send_federation_request(&server, get_server_keys::v2::Request::new()) - .await, - server, - ) - }) - .collect(); - - while let Some(result) = futures.next().await { - debug!("Received new Future result"); - if let (Ok(get_keys_response), origin) = result { - info!("Result is from {origin}"); - if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = services() - .globals - .add_signing_key(&origin, key)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - pub_key_map.write().await.insert(origin.to_string(), result); - } - } - debug!("Done handling Future result"); - } - - Ok(()) - } - - pub(crate) async fn fetch_join_signing_keys( - &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, - pub_key_map: &RwLock>>, - ) -> Result<()> { - let mut servers: BTreeMap> = BTreeMap::new(); - - { - let mut pkm = pub_key_map.write().await; - - // Try to fetch keys, failure is okay - // Servers we couldn't find in the cache will be added to `servers` - for pdu in &event.room_state.state { - _ = self - .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) - .await; - } - for pdu in &event.room_state.auth_chain { - _ = self - .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) - .await; - } - - drop(pkm); - }; - - if servers.is_empty() { - trace!("We had all keys cached locally, not fetching any keys from remote servers"); - return Ok(()); - } - - if services().globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ - homeserver signing keys." - ); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - info!("Trusted server supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - } else { - info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - info!("Individual homeservers supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - } - - debug!("Search for signing keys done"); - - /*if servers.is_empty() { - warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); - }*/ - - Ok(()) - } - /// Returns Ok if the acl allows the server pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { let acl_event = if let Some(acl) = @@ -1619,16 +1266,16 @@ impl Service { .state_accessor .room_state_get(room_id, &StateEventType::RoomServerAcl, "")? { - debug!("ACL event found: {acl:?}"); + trace!("ACL event found: {acl:?}"); acl } else { - debug!("No ACL event found"); + trace!("No ACL event found"); return Ok(()); }; let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) { Ok(content) => { - debug!("Found ACL event contents: {content:?}"); + trace!("Found ACL event contents: {content:?}"); content }, Err(e) => { @@ -1644,258 +1291,14 @@ impl Service { } if acl_event_content.is_allowed(server_name) { - debug!("server {server_name} is allowed by ACL"); + trace!("server {server_name} is allowed by ACL"); Ok(()) } else { - info!("Server {} was denied by room ACL in {}", server_name, room_id); + debug!("Server {} was denied by room ACL in {}", server_name, room_id); Err(Error::BadRequest(ErrorKind::forbidden(), "Server was denied by room ACL")) } } - /// Search the DB for the signing keys of the given server, if we don't have - /// them fetch them from the server and save to our DB. - #[tracing::instrument(skip_all)] - pub async fn fetch_signing_keys_for_server( - &self, origin: &ServerName, signature_ids: Vec, - ) -> Result> { - let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - - let permit = services() - .globals - .servername_ratelimiter - .read() - .await - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); - - let permit = if let Some(p) = permit { - p - } else { - let mut write = services().globals.servername_ratelimiter.write().await; - let s = Arc::clone( - write - .entry(origin.to_owned()) - .or_insert_with(|| Arc::new(Semaphore::new(1))), - ); - - s.acquire_owned() - } - .await; - - let back_off = |id| async { - match services() - .globals - .bad_signature_ratelimiter - .write() - .await - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - }, - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - } - }; - - if let Some((time, tries)) = services() - .globals - .bad_signature_ratelimiter - .read() - .await - .get(&signature_ids) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {:?}", signature_ids); - return Err(Error::BadServerResponse("bad signature, still backing off")); - } - } - - let mut result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if contains_all_ids(&result) { - trace!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); - return Ok(result); - } - - // i didnt split this out into their own functions because it's relatively small - if services().globals.query_trusted_key_servers_first() { - info!( - "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ - keys" - ); - - for server in services().globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) { - debug!("Got signing keys: {:?}", server_keys); - for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - services() - .globals - .add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - } else { - info!("query_trusted_key_servers_first is set to false, querying {origin} first"); - - debug!("Asking {origin} for their signing keys over federation"); - if let Some(server_key) = services() - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - services() - .globals - .add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in services().globals.trusted_servers() { - debug!("Asking notary server {server} for {origin}'s signing key"); - if let Some(server_keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) { - debug!("Got signing keys: {:?}", server_keys); - for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - } - - drop(permit); - - back_off(signature_ids).await; - - warn!("Failed to find public key for server: {origin}"); - Err(Error::BadServerResponse("Failed to find public key for server")) - } - fn check_room_id(&self, room_id: &RoomId, pdu: &PduEvent) -> Result<()> { if pdu.room_id != room_id { warn!("Found event from room {} in room {}", pdu.room_id, room_id); diff --git a/src/service/rooms/event_handler/signing_keys.rs b/src/service/rooms/event_handler/signing_keys.rs new file mode 100644 index 00000000..52ad5186 --- /dev/null +++ b/src/service/rooms/event_handler/signing_keys.rs @@ -0,0 +1,615 @@ +use std::{ + collections::{hash_map, HashSet}, + time::{Duration, Instant, SystemTime}, +}; + +use futures_util::{stream::FuturesUnordered, StreamExt}; +use ruma::{ + api::federation::{ + discovery::{ + get_remote_server_keys, + get_remote_server_keys_batch::{self, v2::QueryCriteria}, + get_server_keys, + }, + membership::create_join_event, + }, + serde::Base64, + CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedServerName, + OwnedServerSigningKeyId, RoomVersionId, ServerName, +}; +use serde_json::value::RawValue as RawJsonValue; +use tokio::sync::{RwLock, RwLockWriteGuard, Semaphore}; +use tracing::{debug, error, info, trace, warn}; + +use crate::{ + service::{Arc, BTreeMap, HashMap, Result}, + services, Error, +}; + +impl super::Service { + pub(crate) async fn fetch_required_signing_keys<'a, E>( + &'a self, events: E, pub_key_map: &RwLock>>, + ) -> Result<()> + where + E: IntoIterator>, + { + let mut server_key_ids = HashMap::new(); + + for event in events { + debug!("Fetching keys for event: {event:?}"); + for (signature_server, signature) in event + .get("signatures") + .ok_or(Error::BadServerResponse("No signatures in server response pdu."))? + .as_object() + .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))? + { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + for signature_id in signature_object.keys() { + server_key_ids + .entry(signature_server.clone()) + .or_insert_with(HashSet::new) + .insert(signature_id.clone()); + } + } + } + + if server_key_ids.is_empty() { + // Nothing to do, can exit early + trace!("server_key_ids is empty, not fetching any keys"); + return Ok(()); + } + + debug!( + "Fetch keys for {}", + server_key_ids + .keys() + .cloned() + .collect::>() + .join(", ") + ); + + let mut server_keys: FuturesUnordered<_> = server_key_ids + .into_iter() + .map(|(signature_server, signature_ids)| async { + let fetch_res = self + .fetch_signing_keys_for_server( + signature_server.as_str().try_into().map_err(|e| { + info!("Invalid servername in signatures of server response pdu: {e}"); + ( + signature_server.clone(), + Error::BadServerResponse("Invalid servername in signatures of server response pdu."), + ) + })?, + signature_ids.into_iter().collect(), // HashSet to Vec + ) + .await; + + match fetch_res { + Ok(keys) => Ok((signature_server, keys)), + Err(e) => { + warn!("Signature verification failed: Could not fetch signing key for {signature_server}: {e}",); + Err((signature_server, e)) + }, + } + }) + .collect(); + + while let Some(fetch_res) = server_keys.next().await { + match fetch_res { + Ok((signature_server, keys)) => { + pub_key_map + .write() + .await + .insert(signature_server.clone(), keys); + }, + Err((signature_server, e)) => { + warn!("Failed to fetch keys for {}: {:?}", signature_server, e); + }, + } + } + + Ok(()) + } + + // Gets a list of servers for which we don't have the signing key yet. We go + // over the PDUs and either cache the key or add it to the list that needs to be + // retrieved. + async fn get_server_keys_from_cache( + &self, pdu: &RawJsonValue, + servers: &mut BTreeMap>, + room_version: &RoomVersionId, + pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, + ) -> Result<()> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + + if let Some((time, tries)) = services() + .globals + .bad_event_ratelimiter + .read() + .await + .get(event_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let signatures = value + .get("signatures") + .ok_or(Error::BadServerResponse("No signatures in server response pdu."))? + .as_object() + .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))?; + + for (signature_server, signature) in signatures { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + let signature_ids = signature_object.keys().cloned().collect::>(); + + let contains_all_ids = + |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + + let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { + info!("Invalid servername in signatures of server response pdu: {e}"); + Error::BadServerResponse("Invalid servername in signatures of server response pdu.") + })?; + + if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { + continue; + } + + debug!("Loading signing keys for {}", origin); + + let result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + if !contains_all_ids(&result) { + debug!("Signing key not loaded for {}", origin); + servers.insert(origin.to_owned(), BTreeMap::new()); + } + + pub_key_map.insert(origin.to_string(), result); + } + + Ok(()) + } + + /// Batch requests homeserver signing keys from trusted notary key servers + /// (`trusted_servers` config option) + async fn batch_request_signing_keys( + &self, mut servers: BTreeMap>, + pub_key_map: &RwLock>>, + ) -> Result<()> { + for server in services().globals.trusted_servers() { + info!("Asking batch signing keys from trusted server {}", server); + match services() + .sending + .send_federation_request( + server, + get_remote_server_keys_batch::v2::Request { + server_keys: servers.clone(), + }, + ) + .await + { + Ok(keys) => { + debug!("Got signing keys: {:?}", keys); + let mut pkm = pub_key_map.write().await; + for k in keys.server_keys { + let k = match k.deserialize() { + Ok(key) => key, + Err(e) => { + warn!("Received error {e} while fetching keys from trusted server {server}"); + warn!("{}", k.into_json()); + continue; + }, + }; + + // TODO: Check signature from trusted server? + servers.remove(&k.server_name); + + let result = services() + .globals + .add_signing_key(&k.server_name, k.clone())? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::>(); + + pkm.insert(k.server_name.to_string(), result); + } + }, + Err(e) => { + warn!( + "Failed sending batched key request to trusted key server {server} for the remote servers \ + {:?}: {e}", + servers + ); + }, + } + } + + Ok(()) + } + + /// Requests multiple homeserver signing keys from individual servers (not + /// trused notary servers) + async fn request_signing_keys( + &self, servers: BTreeMap>, + pub_key_map: &RwLock>>, + ) -> Result<()> { + info!("Asking individual servers for signing keys: {servers:?}"); + let mut futures: FuturesUnordered<_> = servers + .into_keys() + .map(|server| async move { + ( + services() + .sending + .send_federation_request(&server, get_server_keys::v2::Request::new()) + .await, + server, + ) + }) + .collect(); + + while let Some(result) = futures.next().await { + debug!("Received new Future result"); + if let (Ok(get_keys_response), origin) = result { + info!("Result is from {origin}"); + if let Ok(key) = get_keys_response.server_key.deserialize() { + let result: BTreeMap<_, _> = services() + .globals + .add_signing_key(&origin, key)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + pub_key_map.write().await.insert(origin.to_string(), result); + } + } + debug!("Done handling Future result"); + } + + Ok(()) + } + + pub(crate) async fn fetch_join_signing_keys( + &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, + pub_key_map: &RwLock>>, + ) -> Result<()> { + let mut servers: BTreeMap> = BTreeMap::new(); + + { + let mut pkm = pub_key_map.write().await; + + // Try to fetch keys, failure is okay + // Servers we couldn't find in the cache will be added to `servers` + for pdu in &event.room_state.state { + _ = self + .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) + .await; + } + for pdu in &event.room_state.auth_chain { + _ = self + .get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm) + .await; + } + + drop(pkm); + }; + + if servers.is_empty() { + trace!("We had all keys cached locally, not fetching any keys from remote servers"); + return Ok(()); + } + + if services().globals.query_trusted_key_servers_first() { + info!( + "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ + homeserver signing keys." + ); + + self.batch_request_signing_keys(servers.clone(), pub_key_map) + .await?; + + if servers.is_empty() { + info!("Trusted server supplied all signing keys, no more keys to fetch"); + return Ok(()); + } + + info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); + + self.request_signing_keys(servers.clone(), pub_key_map) + .await?; + } else { + info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); + + self.request_signing_keys(servers.clone(), pub_key_map) + .await?; + + if servers.is_empty() { + info!("Individual homeservers supplied all signing keys, no more keys to fetch"); + return Ok(()); + } + + info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); + + self.batch_request_signing_keys(servers.clone(), pub_key_map) + .await?; + } + + debug!("Search for signing keys done"); + + /*if servers.is_empty() { + warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); + }*/ + + Ok(()) + } + + /// Search the DB for the signing keys of the given server, if we don't have + /// them fetch them from the server and save to our DB. + #[tracing::instrument(skip_all)] + pub async fn fetch_signing_keys_for_server( + &self, origin: &ServerName, signature_ids: Vec, + ) -> Result> { + let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + + let permit = services() + .globals + .servername_ratelimiter + .read() + .await + .get(origin) + .map(|s| Arc::clone(s).acquire_owned()); + + let permit = if let Some(p) = permit { + p + } else { + let mut write = services().globals.servername_ratelimiter.write().await; + let s = Arc::clone( + write + .entry(origin.to_owned()) + .or_insert_with(|| Arc::new(Semaphore::new(1))), + ); + + s.acquire_owned() + } + .await; + + let back_off = |id| async { + match services() + .globals + .bad_signature_ratelimiter + .write() + .await + .entry(id) + { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + } + }; + + if let Some((time, tries)) = services() + .globals + .bad_signature_ratelimiter + .read() + .await + .get(&signature_ids) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {:?}", signature_ids); + return Err(Error::BadServerResponse("bad signature, still backing off")); + } + } + + let mut result: BTreeMap<_, _> = services() + .globals + .signing_keys_for(origin)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + + if contains_all_ids(&result) { + trace!("We have all homeserver signing keys locally for {origin}, not fetching any remotely"); + return Ok(result); + } + + // i didnt split this out into their own functions because it's relatively small + if services().globals.query_trusted_key_servers_first() { + info!( + "query_trusted_key_servers_first is set to true, querying notary trusted servers first for {origin} \ + keys" + ); + + for server in services().globals.trusted_servers() { + debug!("Asking notary server {server} for {origin}'s signing key"); + if let Some(server_keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys::v2::Request::new( + origin.to_owned(), + MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime too large"), + ) + .expect("time is valid"), + ), + ) + .await + .ok() + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::>() + }) { + debug!("Got signing keys: {:?}", server_keys); + for k in server_keys { + services().globals.add_signing_key(origin, k.clone())?; + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + } + + if contains_all_ids(&result) { + return Ok(result); + } + } + } + + debug!("Asking {origin} for their signing keys over federation"); + if let Some(server_key) = services() + .sending + .send_federation_request(origin, get_server_keys::v2::Request::new()) + .await + .ok() + .and_then(|resp| resp.server_key.deserialize().ok()) + { + services() + .globals + .add_signing_key(origin, server_key.clone())?; + + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + + if contains_all_ids(&result) { + return Ok(result); + } + } + } else { + info!("query_trusted_key_servers_first is set to false, querying {origin} first"); + + debug!("Asking {origin} for their signing keys over federation"); + if let Some(server_key) = services() + .sending + .send_federation_request(origin, get_server_keys::v2::Request::new()) + .await + .ok() + .and_then(|resp| resp.server_key.deserialize().ok()) + { + services() + .globals + .add_signing_key(origin, server_key.clone())?; + + result.extend( + server_key + .verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + server_key + .old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + + if contains_all_ids(&result) { + return Ok(result); + } + } + + for server in services().globals.trusted_servers() { + debug!("Asking notary server {server} for {origin}'s signing key"); + if let Some(server_keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys::v2::Request::new( + origin.to_owned(), + MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + .checked_add(Duration::from_secs(3600)) + .expect("SystemTime too large"), + ) + .expect("time is valid"), + ), + ) + .await + .ok() + .map(|resp| { + resp.server_keys + .into_iter() + .filter_map(|e| e.deserialize().ok()) + .collect::>() + }) { + debug!("Got signing keys: {:?}", server_keys); + for k in server_keys { + services().globals.add_signing_key(origin, k.clone())?; + result.extend( + k.verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + result.extend( + k.old_verify_keys + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)), + ); + } + + if contains_all_ids(&result) { + return Ok(result); + } + } + } + } + + drop(permit); + + back_off(signature_ids).await; + + warn!("Failed to find public key for server: {origin}"); + Err(Error::BadServerResponse("Failed to find public key for server")) + } +}