diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index 5676d7b1..afe5a1e5 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -4,7 +4,7 @@ use std::{ }; use conduwuit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result}; -use futures::FutureExt; +use futures::{FutureExt, TryFutureExt}; use hickory_resolver::error::ResolveError; use ipaddress::IPAddress; use ruma::ServerName; @@ -29,18 +29,31 @@ impl ActualDest { impl super::Service { #[tracing::instrument(skip_all, level = "debug", name = "resolve")] pub(crate) async fn get_actual_dest(&self, server_name: &ServerName) -> Result { - let (result, cached) = if let Ok(result) = self.cache.get_destination(server_name).await { - (result, true) - } else { - self.validate_dest(server_name)?; - (self.resolve_actual_dest(server_name, true).boxed().await?, false) - }; - - let CachedDest { dest, host, .. } = result; + let (CachedDest { dest, host, .. }, cached) = + self.lookup_actual_dest(server_name).await?; Ok(ActualDest { dest, host, cached }) } + pub(crate) async fn lookup_actual_dest( + &self, + server_name: &ServerName, + ) -> Result<(CachedDest, bool)> { + if let Ok(result) = self.cache.get_destination(server_name).await { + return Ok((result, true)); + } + + let _dedup = self.resolving.lock(server_name.as_str()); + if let Ok(result) = self.cache.get_destination(server_name).await { + return Ok((result, true)); + } + + self.resolve_actual_dest(server_name, true) + .map_ok(|result| (result, false)) + .boxed() + .await + } + /// Returns: `actual_destination`, host header /// Implemented according to the specification at /// Numbers in comments below refer to bullet points in linked section of @@ -51,7 +64,7 @@ impl super::Service { dest: &ServerName, cache: bool, ) -> Result { - trace!("Finding actual destination for {dest}"); + self.validate_dest(dest)?; let mut host = dest.as_str().to_owned(); let actual_dest = match get_ip_with_port(dest.as_str()) { | Some(host_port) => Self::actual_dest_1(host_port)?, @@ -106,6 +119,7 @@ impl super::Service { cache, ) .await?; + Ok(FedDest::Named( host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port()), @@ -156,6 +170,7 @@ impl super::Service { cache, ) .await?; + Ok(FedDest::Named( host.to_owned(), port.try_into().unwrap_or_else(|_| FedDest::default_port()), @@ -177,17 +192,18 @@ impl super::Service { cache, ) .await?; + if let Some(port) = force_port { - Ok(FedDest::Named( + return Ok(FedDest::Named( delegated, format!(":{port}") .as_str() .try_into() .unwrap_or_else(|_| FedDest::default_port()), - )) - } else { - Ok(add_port_to_hostname(&delegated)) + )); } + + Ok(add_port_to_hostname(&delegated)) } async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result { @@ -212,21 +228,24 @@ impl super::Service { cache, ) .await?; + if let Some(port) = force_port { let port = format!(":{port}"); - Ok(FedDest::Named( + + return Ok(FedDest::Named( host.to_owned(), PortString::from(port.as_str()).unwrap_or_else(|_| FedDest::default_port()), - )) - } else { - Ok(add_port_to_hostname(host)) + )); } + + Ok(add_port_to_hostname(host)) } async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result { debug!("5: No SRV record found"); self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache) .await?; + Ok(add_port_to_hostname(dest.as_str())) } diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 3163b0d0..090e562d 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -6,7 +6,8 @@ mod tests; use std::sync::Arc; -use conduwuit::{Result, Server}; +use arrayvec::ArrayString; +use conduwuit::{utils::MutexMap, Result, Server}; use self::{cache::Cache, dns::Resolver}; use crate::{client, Dep}; @@ -14,6 +15,7 @@ use crate::{client, Dep}; pub struct Service { pub cache: Arc, pub resolver: Arc, + resolving: Resolving, services: Services, } @@ -22,6 +24,9 @@ struct Services { client: Dep, } +type Resolving = MutexMap; +type NameBuf = ArrayString<256>; + impl crate::Service for Service { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn build(args: crate::Args<'_>) -> Result> { @@ -29,6 +34,7 @@ impl crate::Service for Service { Ok(Arc::new(Self { cache: cache.clone(), resolver: Resolver::build(args.server, cache)?, + resolving: MutexMap::new(), services: Services { server: args.server.clone(), client: args.depend::("client"),