diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index afe5a1e5..1a36936d 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -18,7 +18,6 @@ use super::{ pub(crate) struct ActualDest { pub(crate) dest: FedDest, pub(crate) host: String, - pub(crate) cached: bool, } impl ActualDest { @@ -29,10 +28,10 @@ impl ActualDest { impl super::Service { #[tracing::instrument(skip_all, level = "debug", name = "resolve")] pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result { - let (CachedDest { dest, host, .. }, cached) = + let (CachedDest { dest, host, .. }, _cached) = self.lookup_actual_dest(server_name).await?; - Ok(ActualDest { dest, host, cached }) + Ok(ActualDest { dest, host }) } pub(crate) async fn lookup_actual_dest( @@ -49,6 +48,7 @@ impl super::Service { } self.resolve_actual_dest(server_name, true) + .inspect_ok(|result| self.cache.set_destination(server_name, result)) .map_ok(|result| (result, false)) .boxed() .await @@ -334,7 +334,7 @@ impl super::Service { debug_info!("{overname:?} overriden by {hostname:?}"); } - self.cache.set_override(overname, CachedOverride { + self.cache.set_override(overname, &CachedOverride { ips: override_ip.into_iter().take(MAX_IPS).collect(), port, expire: CachedOverride::default_expire(), diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 11e6c9bd..657718b3 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -45,12 +45,12 @@ impl Cache { } #[implement(Cache)] -pub fn set_destination(&self, name: &ServerName, dest: CachedDest) { +pub fn set_destination(&self, name: &ServerName, dest: &CachedDest) { self.destinations.raw_put(name, Cbor(dest)); } #[implement(Cache)] -pub fn set_override(&self, name: &str, over: CachedOverride) { +pub fn set_override(&self, name: &str, over: &CachedOverride) { self.overrides.raw_put(name, Cbor(over)); } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index 831a1dd8..c8a64f3c 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -18,10 +18,7 @@ use ruma::{ CanonicalJsonObject, CanonicalJsonValue, ServerName, ServerSigningKeyId, }; -use crate::{ - resolver, - resolver::{actual::ActualDest, cache::CachedDest}, -}; +use crate::resolver::actual::ActualDest; impl super::Service { #[tracing::instrument( @@ -73,16 +70,7 @@ impl super::Service { debug!(?method, ?url, "Sending request"); match client.execute(request).await { - | Ok(response) => - handle_response::( - &self.services.resolver, - dest, - actual, - &method, - &url, - response, - ) - .await, + | Ok(response) => handle_response::(dest, actual, &method, &url, response).await, | Err(error) => Err(handle_error(actual, &method, &url, error).expect_err("always returns error")), } @@ -111,7 +99,6 @@ impl super::Service { } async fn handle_response( - resolver: &resolver::Service, dest: &ServerName, actual: &ActualDest, method: &Method, @@ -122,17 +109,9 @@ where T: OutgoingRequest + Send, { let response = into_http_response(dest, actual, method, url, response).await?; - let result = T::IncomingResponse::try_from_http_response(response); - if result.is_ok() && !actual.cached { - resolver.cache.set_destination(dest, CachedDest { - dest: actual.dest.clone(), - host: actual.host.clone(), - expire: CachedDest::default_expire(), - }); - } - - result.map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) + T::IncomingResponse::try_from_http_response(response) + .map_err(|e| err!(BadServerResponse("Server returned bad 200 response: {e:?}"))) } async fn into_http_response(