fix SRV override loss on cache expiration

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-01-22 23:07:13 +00:00
parent 265802d546
commit a5520e8b1b
4 changed files with 63 additions and 40 deletions

View file

@ -51,12 +51,14 @@ async fn destinations_cache(
async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessageEventContent> { async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessageEventContent> {
use service::resolver::cache::CachedOverride; use service::resolver::cache::CachedOverride;
writeln!(self, "| Server Name | IP | Port | Expires |").await?; writeln!(self, "| Server Name | IP | Port | Expires | Overriding |").await?;
writeln!(self, "| ----------- | --- | ----:| ------- |").await?; writeln!(self, "| ----------- | --- | ----:| ------- | ---------- |").await?;
let mut overrides = self.services.resolver.cache.overrides().boxed(); let mut overrides = self.services.resolver.cache.overrides().boxed();
while let Some((name, CachedOverride { ips, port, expire })) = overrides.next().await { while let Some((name, CachedOverride { ips, port, expire, overriding })) =
overrides.next().await
{
if let Some(server_name) = server_name.as_ref() { if let Some(server_name) = server_name.as_ref() {
if name != server_name { if name != server_name {
continue; continue;
@ -64,7 +66,7 @@ async fn overrides_cache(&self, server_name: Option<String>) -> Result<RoomMessa
} }
let expire = time::format(expire, "%+"); let expire = time::format(expire, "%+");
self.write_str(&format!("| {name} | {ips:?} | {port} | {expire} |\n")) self.write_str(&format!("| {name} | {ips:?} | {port} | {expire} | {overriding:?} |\n"))
.await?; .await?;
} }

View file

@ -112,12 +112,7 @@ impl super::Service {
async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> { async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result<FedDest> {
debug!("2: Hostname with included port"); debug!("2: Hostname with included port");
let (host, port) = dest.as_str().split_at(pos); let (host, port) = dest.as_str().split_at(pos);
self.conditional_query_and_cache_override( self.conditional_query_and_cache(host, port.parse::<u16>().unwrap_or(8448), cache)
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?; .await?;
Ok(FedDest::Named( Ok(FedDest::Named(
@ -163,12 +158,7 @@ impl super::Service {
) -> Result<FedDest> { ) -> Result<FedDest> {
debug!("3.2: Hostname with port in .well-known file"); debug!("3.2: Hostname with port in .well-known file");
let (host, port) = delegated.split_at(pos); let (host, port) = delegated.split_at(pos);
self.conditional_query_and_cache_override( self.conditional_query_and_cache(host, port.parse::<u16>().unwrap_or(8448), cache)
host,
host,
port.parse::<u16>().unwrap_or(8448),
cache,
)
.await?; .await?;
Ok(FedDest::Named( Ok(FedDest::Named(
@ -208,7 +198,7 @@ impl super::Service {
async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result<FedDest> { async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result<FedDest> {
debug!("3.4: No SRV records, just use the hostname from .well-known"); debug!("3.4: No SRV records, just use the hostname from .well-known");
self.conditional_query_and_cache_override(&delegated, &delegated, 8448, cache) self.conditional_query_and_cache(&delegated, 8448, cache)
.await?; .await?;
Ok(add_port_to_hostname(&delegated)) Ok(add_port_to_hostname(&delegated))
} }
@ -243,7 +233,7 @@ impl super::Service {
async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> { async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result<FedDest> {
debug!("5: No SRV record found"); debug!("5: No SRV record found");
self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache) self.conditional_query_and_cache(dest.as_str(), 8448, cache)
.await?; .await?;
Ok(add_port_to_hostname(dest.as_str())) Ok(add_port_to_hostname(dest.as_str()))
@ -251,9 +241,7 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "well-known")] #[tracing::instrument(skip_all, name = "well-known")]
async fn request_well_known(&self, dest: &str) -> Result<Option<String>> { async fn request_well_known(&self, dest: &str) -> Result<Option<String>> {
if !self.cache.has_override(dest).await { self.conditional_query_and_cache(dest, 8448, true).await?;
self.query_and_cache_override(dest, dest, 8448).await?;
}
self.services.server.check_running()?; self.services.server.check_running()?;
trace!("Requesting well known for {dest}"); trace!("Requesting well known for {dest}");
@ -301,6 +289,17 @@ impl super::Service {
Ok(Some(m_server.to_owned())) Ok(Some(m_server.to_owned()))
} }
/// Queries and caches an override for `hostname`, keyed by `hostname` itself.
///
/// Thin convenience wrapper over
/// [`Self::conditional_query_and_cache_override`] for the common case where
/// the override name and the host being looked up are the same string.
/// `port` and `cache` are forwarded unchanged, and the callee's result is
/// returned as-is.
#[inline]
async fn conditional_query_and_cache(
    &self,
    hostname: &str,
    port: u16,
    cache: bool,
) -> Result {
    // No separate override name here: the entry is keyed by the queried host.
    let overname = hostname;

    self.conditional_query_and_cache_override(overname, hostname, port, cache)
        .await
}
#[inline] #[inline]
async fn conditional_query_and_cache_override( async fn conditional_query_and_cache_override(
&self, &self,
@ -308,13 +307,17 @@ impl super::Service {
hostname: &str, hostname: &str,
port: u16, port: u16,
cache: bool, cache: bool,
) -> Result<()> { ) -> Result {
if cache { if !cache {
return Ok(());
}
if self.cache.has_override(overname).await {
return Ok(());
}
self.query_and_cache_override(overname, hostname, port) self.query_and_cache_override(overname, hostname, port)
.await .await
} else {
Ok(())
}
} }
#[tracing::instrument(skip(self, overname, port), name = "ip")] #[tracing::instrument(skip(self, overname, port), name = "ip")]
@ -323,21 +326,20 @@ impl super::Service {
overname: &'_ str, overname: &'_ str,
hostname: &'_ str, hostname: &'_ str,
port: u16, port: u16,
) -> Result<()> { ) -> Result {
self.services.server.check_running()?; self.services.server.check_running()?;
debug!("querying IP for {overname:?} ({hostname:?}:{port})"); debug!("querying IP for {overname:?} ({hostname:?}:{port})");
match self.resolver.resolver.lookup_ip(hostname.to_owned()).await { match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
| Err(e) => Self::handle_resolve_error(&e, hostname), | Err(e) => Self::handle_resolve_error(&e, hostname),
| Ok(override_ip) => { | Ok(override_ip) => {
if hostname != overname {
debug_info!("{overname:?} overriden by {hostname:?}");
}
self.cache.set_override(overname, &CachedOverride { self.cache.set_override(overname, &CachedOverride {
ips: override_ip.into_iter().take(MAX_IPS).collect(), ips: override_ip.into_iter().take(MAX_IPS).collect(),
port, port,
expire: CachedOverride::default_expire(), expire: CachedOverride::default_expire(),
overriding: (hostname != overname)
.then_some(hostname.into())
.inspect(|_| debug_info!("{overname:?} overriden by {hostname:?}")),
}); });
Ok(()) Ok(())

View file

@ -30,6 +30,7 @@ pub struct CachedOverride {
pub ips: IpAddrs, pub ips: IpAddrs,
pub port: u16, pub port: u16,
pub expire: SystemTime, pub expire: SystemTime,
pub overriding: Option<String>,
} }
pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>; pub type IpAddrs = ArrayVec<IpAddr, MAX_IPS>;
@ -63,7 +64,10 @@ pub async fn has_destination(&self, destination: &ServerName) -> bool {
#[implement(Cache)] #[implement(Cache)]
#[must_use] #[must_use]
pub async fn has_override(&self, destination: &str) -> bool { pub async fn has_override(&self, destination: &str) -> bool {
self.get_override(destination).await.is_ok() self.get_override(destination)
.await
.iter()
.any(CachedOverride::valid)
} }
#[implement(Cache)] #[implement(Cache)]
@ -85,9 +89,6 @@ pub async fn get_override(&self, name: &str) -> Result<CachedOverride> {
.await .await
.deserialized::<Cbor<_>>() .deserialized::<Cbor<_>>()
.map(at!(0)) .map(at!(0))
.into_iter()
.find(CachedOverride::valid)
.ok_or(err!(Request(NotFound("Expired from cache"))))
} }
#[implement(Cache)] #[implement(Cache)]

View file

@ -93,6 +93,11 @@ impl Resolve for Hooked {
} }
} }
#[tracing::instrument(
level = "debug",
skip_all,
fields(name = ?name.as_str())
)]
async fn hooked_resolve( async fn hooked_resolve(
cache: Arc<Cache>, cache: Arc<Cache>,
server: Arc<Server>, server: Arc<Server>,
@ -100,8 +105,21 @@ async fn hooked_resolve(
name: Name, name: Name,
) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> { ) -> Result<Addrs, Box<dyn std::error::Error + Send + Sync>> {
match cache.get_override(name.as_str()).await { match cache.get_override(name.as_str()).await {
| Ok(cached) => cached_to_reqwest(cached).await, | Ok(cached) if cached.valid() => cached_to_reqwest(cached).await,
| Err(_) => resolve_to_reqwest(server, resolver, name).boxed().await, | Ok(CachedOverride { overriding, .. }) if overriding.is_some() =>
resolve_to_reqwest(
server,
resolver,
overriding
.as_deref()
.map(str::parse)
.expect("overriding is set for this record")
.expect("overriding is a valid internet name"),
)
.boxed()
.await,
| _ => resolve_to_reqwest(server, resolver, name).boxed().await,
} }
} }