diff --git a/src/service/resolver/actual.rs b/src/service/resolver/actual.rs index b037cf77..1ad76f66 100644 --- a/src/service/resolver/actual.rs +++ b/src/service/resolver/actual.rs @@ -3,7 +3,7 @@ use std::{ net::{IpAddr, SocketAddr}, }; -use conduwuit::{Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace}; +use conduwuit::{Err, Result, debug, debug_info, err, error, trace}; use futures::{FutureExt, TryFutureExt}; use hickory_resolver::error::ResolveError; use ipaddress::IPAddress; @@ -72,6 +72,9 @@ impl super::Service { if let Some(pos) = dest.as_str().find(':') { self.actual_dest_2(dest, cache, pos).await? } else { + self.conditional_query_and_cache(dest.as_str(), 8448, true) + .await?; + self.services.server.check_running()?; match self.request_well_known(dest.as_str()).await? { | Some(delegated) => self.actual_dest_3(&mut host, cache, delegated).await?, @@ -243,56 +246,6 @@ impl super::Service { Ok(add_port_to_hostname(dest.as_str())) } - #[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))] - async fn request_well_known(&self, dest: &str) -> Result> { - self.conditional_query_and_cache(dest, 8448, true).await?; - - self.services.server.check_running()?; - trace!("Requesting well known for {dest}"); - let response = self - .services - .client - .well_known - .get(format!("https://{dest}/.well-known/matrix/server")) - .send() - .await; - - trace!("response: {response:?}"); - if let Err(e) = &response { - debug!("error: {e:?}"); - return Ok(None); - } - - let response = response?; - if !response.status().is_success() { - debug!("response not 2XX"); - return Ok(None); - } - - let text = response.text().await?; - trace!("response text: {text:?}"); - if text.len() >= 12288 { - debug_warn!("response contains junk"); - return Ok(None); - } - - let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); - - let m_server = body - .get("m.server") - .unwrap_or(&serde_json::Value::Null) - .as_str() - .unwrap_or_default(); - - if ruma::identifiers_validation::server_name::validate(m_server).is_err() { - debug_error!("response content missing or invalid"); - return Ok(None); - } - - debug_info!("{dest:?} found at {m_server:?}"); - Ok(Some(m_server.to_owned())) - } - #[inline] async fn conditional_query_and_cache( &self, diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 246d6bc1..c513cec9 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -2,7 +2,9 @@ pub mod actual; pub mod cache; mod dns; pub mod fed; +#[cfg(test)] mod tests; +mod well_known; use std::sync::Arc; diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs index 6e9d0e71..068e08bd 100644 --- a/src/service/resolver/tests.rs +++ b/src/service/resolver/tests.rs @@ -1,5 +1,3 @@ -#![cfg(test)] - use super::fed::{FedDest, add_port_to_hostname, get_ip_with_port}; #[test] diff --git a/src/service/resolver/well_known.rs b/src/service/resolver/well_known.rs new file mode 100644 index 00000000..68a8e620 --- /dev/null +++ b/src/service/resolver/well_known.rs @@ -0,0 +1,49 @@ +use conduwuit::{Result, debug, debug_error, debug_info, debug_warn, implement, trace}; + +#[implement(super::Service)] +#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))] +pub(super) async fn request_well_known(&self, dest: &str) -> Result> { + trace!("Requesting well known for {dest}"); + let response = self + .services + .client + .well_known + .get(format!("https://{dest}/.well-known/matrix/server")) + .send() + .await; + + trace!("response: {response:?}"); + if let Err(e) = &response { + debug!("error: {e:?}"); + return Ok(None); + } + + let response = response?; + if !response.status().is_success() { + debug!("response not 2XX"); + return Ok(None); + } + + let text = response.text().await?; + trace!("response text: {text:?}"); + if text.len() >= 12288 { + debug_warn!("response contains junk"); + return Ok(None); + } + + let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); + + let m_server = body + .get("m.server") + .unwrap_or(&serde_json::Value::Null) + .as_str() + .unwrap_or_default(); + + if ruma::identifiers_validation::server_name::validate(m_server).is_err() { + debug_error!("response content missing or invalid"); + return Ok(None); + } + + debug_info!("{dest:?} found at {m_server:?}"); + Ok(Some(m_server.to_owned())) +}