From a5520e8b1bc1c4ddb9090dc9b93ef76899e58d9a Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Wed, 22 Jan 2025 23:07:13 +0000 Subject: [PATCH] fix SRV override loss on cache expiration Signed-off-by: Jason Volk --- src/admin/query/resolver.rs | 10 +++--- src/service/resolver/actual.rs | 62 ++++++++++++++++++---------------- src/service/resolver/cache.rs | 9 ++--- src/service/resolver/dns.rs | 22 ++++++++++-- 4 files changed, 63 insertions(+), 40 deletions(-) diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 0b6da6fd..08b5d171 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -51,12 +51,14 @@ async fn destinations_cache( async fn overrides_cache(&self, server_name: Option) -> Result { use service::resolver::cache::CachedOverride; - writeln!(self, "| Server Name | IP | Port | Expires |").await?; - writeln!(self, "| ----------- | --- | ----:| ------- |").await?; + writeln!(self, "| Server Name | IP | Port | Expires | Overriding |").await?; + writeln!(self, "| ----------- | --- | ----:| ------- | ---------- |").await?; let mut overrides = self.services.resolver.cache.overrides().boxed(); - while let Some((name, CachedOverride { ips, port, expire })) = overrides.next().await { + while let Some((name, CachedOverride { ips, port, expire, overriding })) = + overrides.next().await + { if let Some(server_name) = server_name.as_ref() { if name != server_name { continue; @@ -64,7 +66,7 @@ async fn overrides_cache(&self, server_name: Option) -> Result Result { debug!("2: Hostname with included port"); let (host, port) = dest.as_str().split_at(pos); - self.conditional_query_and_cache_override( - host, - host, - port.parse::().unwrap_or(8448), - cache, - ) - .await?; + self.conditional_query_and_cache(host, port.parse::().unwrap_or(8448), cache) + .await?; Ok(FedDest::Named( host.to_owned(), @@ -163,13 +158,8 @@ impl super::Service { ) -> Result { debug!("3.2: Hostname with port in .well-known file"); let (host, port) = delegated.split_at(pos); - self.conditional_query_and_cache_override( - host, - host, - port.parse::().unwrap_or(8448), - cache, - ) - .await?; + self.conditional_query_and_cache(host, port.parse::().unwrap_or(8448), cache) + .await?; Ok(FedDest::Named( host.to_owned(), @@ -208,7 +198,7 @@ impl super::Service { async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result { debug!("3.4: No SRV records, just use the hostname from .well-known"); - self.conditional_query_and_cache_override(&delegated, &delegated, 8448, cache) + self.conditional_query_and_cache(&delegated, 8448, cache) .await?; Ok(add_port_to_hostname(&delegated)) } @@ -243,7 +233,7 @@ impl super::Service { async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result { debug!("5: No SRV record found"); - self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache) + self.conditional_query_and_cache(dest.as_str(), 8448, cache) .await?; Ok(add_port_to_hostname(dest.as_str())) @@ -251,9 +241,7 @@ impl super::Service { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(&self, dest: &str) -> Result> { - if !self.cache.has_override(dest).await { - self.query_and_cache_override(dest, dest, 8448).await?; - } + self.conditional_query_and_cache(dest, 8448, true).await?; self.services.server.check_running()?; trace!("Requesting well known for {dest}"); @@ -301,6 +289,17 @@ impl super::Service { Ok(Some(m_server.to_owned())) } + #[inline] + async fn conditional_query_and_cache( + &self, + hostname: &str, + port: u16, + cache: bool, + ) -> Result { + self.conditional_query_and_cache_override(hostname, hostname, port, cache) + .await + } + #[inline] async fn conditional_query_and_cache_override( &self, @@ -308,13 +307,17 @@ impl super::Service { hostname: &str, port: u16, cache: bool, - ) -> Result<()> { - if cache { - self.query_and_cache_override(overname, hostname, port) - .await - } else { - Ok(()) + ) -> Result { + if !cache { + return Ok(()); } + + if self.cache.has_override(overname).await { + return Ok(()); + } + + self.query_and_cache_override(overname, hostname, port) + .await } #[tracing::instrument(skip(self, overname, port), name = "ip")] @@ -323,21 +326,20 @@ impl super::Service { overname: &'_ str, hostname: &'_ str, port: u16, - ) -> Result<()> { + ) -> Result { self.services.server.check_running()?; debug!("querying IP for {overname:?} ({hostname:?}:{port})"); match self.resolver.resolver.lookup_ip(hostname.to_owned()).await { | Err(e) => Self::handle_resolve_error(&e, hostname), | Ok(override_ip) => { - if hostname != overname { - debug_info!("{overname:?} overriden by {hostname:?}"); - } - self.cache.set_override(overname, &CachedOverride { ips: override_ip.into_iter().take(MAX_IPS).collect(), port, expire: CachedOverride::default_expire(), + overriding: (hostname != overname) + .then_some(hostname.into()) + .inspect(|_| debug_info!("{overname:?} overriden by {hostname:?}")), }); Ok(()) diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index e64878d4..22a92865 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -30,6 +30,7 @@ pub struct CachedOverride { pub ips: IpAddrs, pub port: u16, pub expire: SystemTime, + pub overriding: Option, } pub type IpAddrs = ArrayVec; @@ -63,7 +64,10 @@ pub async fn has_destination(&self, destination: &ServerName) -> bool { #[implement(Cache)] #[must_use] pub async fn has_override(&self, destination: &str) -> bool { - self.get_override(destination).await.is_ok() + self.get_override(destination) + .await + .iter() + .any(CachedOverride::valid) } #[implement(Cache)] @@ -85,9 +89,6 @@ pub async fn get_override(&self, name: &str) -> Result { .await .deserialized::>() .map(at!(0)) - .into_iter() - .find(CachedOverride::valid) - .ok_or(err!(Request(NotFound("Expired from cache")))) } #[implement(Cache)] diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index ad7768bc..ca6106e2 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -93,6 +93,11 @@ impl Resolve for Hooked { } } +#[tracing::instrument( + level = "debug", + skip_all, + fields(name = ?name.as_str()) +)] async fn hooked_resolve( cache: Arc, server: Arc, @@ -100,8 +105,21 @@ async fn hooked_resolve( name: Name, ) -> Result> { match cache.get_override(name.as_str()).await { - | Ok(cached) => cached_to_reqwest(cached).await, - | Err(_) => resolve_to_reqwest(server, resolver, name).boxed().await, + | Ok(cached) if cached.valid() => cached_to_reqwest(cached).await, + | Ok(CachedOverride { overriding, .. }) if overriding.is_some() => + resolve_to_reqwest( + server, + resolver, + overriding + .as_deref() + .map(str::parse) + .expect("overriding is set for this record") + .expect("overriding is a valid internet name"), + ) + .boxed() + .await, + + | _ => resolve_to_reqwest(server, resolver, name).boxed().await, } }