diff --git a/src/service/server_keys/acquire.rs b/src/service/server_keys/acquire.rs index 25b676b8..cdaf28b4 100644 --- a/src/service/server_keys/acquire.rs +++ b/src/service/server_keys/acquire.rs @@ -1,15 +1,17 @@ use std::{ borrow::Borrow, collections::{BTreeMap, BTreeSet}, + time::Duration, }; -use conduit::{debug, debug_warn, error, implement, result::FlatOk, warn}; +use conduit::{debug, debug_error, debug_warn, error, implement, result::FlatOk, trace, warn}; use futures::{stream::FuturesUnordered, StreamExt}; use ruma::{ api::federation::discovery::ServerSigningKeys, serde::Raw, CanonicalJsonObject, OwnedServerName, OwnedServerSigningKeyId, ServerName, ServerSigningKeyId, }; use serde_json::value::RawValue as RawJsonValue; +use tokio::time::{timeout_at, Instant}; use super::key_exists; @@ -136,8 +138,12 @@ async fn acquire_origins(&self, batch: I) -> Batch where I: Iterator)> + Send, { + let timeout = Instant::now() + .checked_add(Duration::from_secs(45)) + .expect("timeout overflows"); + let mut requests: FuturesUnordered<_> = batch - .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids)) + .map(|(origin, key_ids)| self.acquire_origin(origin, key_ids, timeout)) .collect(); let mut missing = Batch::new(); @@ -152,11 +158,22 @@ where #[implement(super::Service)] async fn acquire_origin( - &self, origin: OwnedServerName, mut key_ids: Vec, + &self, origin: OwnedServerName, mut key_ids: Vec, timeout: Instant, ) -> (OwnedServerName, Vec) { - if let Ok(server_keys) = self.server_request(&origin).await { - self.add_signing_keys(server_keys.clone()).await; - key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + match timeout_at(timeout, self.server_request(&origin)).await { + Err(e) => debug_warn!(?origin, "timed out: {e}"), + Ok(Err(e)) => debug_error!(?origin, "{e}"), + Ok(Ok(server_keys)) => { + trace!( + %origin, + ?key_ids, + ?server_keys, + "received server_keys" + ); + + self.add_signing_keys(server_keys.clone()).await; + key_ids.retain(|key_id| !key_exists(&server_keys, key_id)); + }, } (origin, key_ids)