split well_known resolver into unit

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2025-03-19 03:49:12 +00:00
parent 8010505853
commit 23e3f6526f
4 changed files with 55 additions and 53 deletions

View file

@ -3,7 +3,7 @@ use std::{
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
}; };
use conduwuit::{Err, Result, debug, debug_error, debug_info, debug_warn, err, error, trace}; use conduwuit::{Err, Result, debug, debug_info, err, error, trace};
use futures::{FutureExt, TryFutureExt}; use futures::{FutureExt, TryFutureExt};
use hickory_resolver::error::ResolveError; use hickory_resolver::error::ResolveError;
use ipaddress::IPAddress; use ipaddress::IPAddress;
@ -72,6 +72,9 @@ impl super::Service {
if let Some(pos) = dest.as_str().find(':') { if let Some(pos) = dest.as_str().find(':') {
self.actual_dest_2(dest, cache, pos).await? self.actual_dest_2(dest, cache, pos).await?
} else { } else {
self.conditional_query_and_cache(dest.as_str(), 8448, true)
.await?;
self.services.server.check_running()?;
match self.request_well_known(dest.as_str()).await? { match self.request_well_known(dest.as_str()).await? {
| Some(delegated) => | Some(delegated) =>
self.actual_dest_3(&mut host, cache, delegated).await?, self.actual_dest_3(&mut host, cache, delegated).await?,
@ -243,56 +246,6 @@ impl super::Service {
Ok(add_port_to_hostname(dest.as_str())) Ok(add_port_to_hostname(dest.as_str()))
} }
#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))]
async fn request_well_known(&self, dest: &str) -> Result<Option<String>> {
self.conditional_query_and_cache(dest, 8448, true).await?;
self.services.server.check_running()?;
trace!("Requesting well known for {dest}");
let response = self
.services
.client
.well_known
.get(format!("https://{dest}/.well-known/matrix/server"))
.send()
.await;
trace!("response: {response:?}");
if let Err(e) = &response {
debug!("error: {e:?}");
return Ok(None);
}
let response = response?;
if !response.status().is_success() {
debug!("response not 2XX");
return Ok(None);
}
let text = response.text().await?;
trace!("response text: {text:?}");
if text.len() >= 12288 {
debug_warn!("response contains junk");
return Ok(None);
}
let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
let m_server = body
.get("m.server")
.unwrap_or(&serde_json::Value::Null)
.as_str()
.unwrap_or_default();
if ruma::identifiers_validation::server_name::validate(m_server).is_err() {
debug_error!("response content missing or invalid");
return Ok(None);
}
debug_info!("{dest:?} found at {m_server:?}");
Ok(Some(m_server.to_owned()))
}
#[inline] #[inline]
async fn conditional_query_and_cache( async fn conditional_query_and_cache(
&self, &self,

View file

@ -2,7 +2,9 @@ pub mod actual;
pub mod cache; pub mod cache;
mod dns; mod dns;
pub mod fed; pub mod fed;
#[cfg(test)]
mod tests; mod tests;
mod well_known;
use std::sync::Arc; use std::sync::Arc;

View file

@ -1,5 +1,3 @@
#![cfg(test)]
use super::fed::{FedDest, add_port_to_hostname, get_ip_with_port}; use super::fed::{FedDest, add_port_to_hostname, get_ip_with_port};
#[test] #[test]

View file

@ -0,0 +1,49 @@
use conduwuit::{Result, debug, debug_error, debug_info, debug_warn, implement, trace};
#[implement(super::Service)]
#[tracing::instrument(name = "well-known", level = "debug", skip(self, dest))]
pub(super) async fn request_well_known(&self, dest: &str) -> Result<Option<String>> {
trace!("Requesting well known for {dest}");
let response = self
.services
.client
.well_known
.get(format!("https://{dest}/.well-known/matrix/server"))
.send()
.await;
trace!("response: {response:?}");
if let Err(e) = &response {
debug!("error: {e:?}");
return Ok(None);
}
let response = response?;
if !response.status().is_success() {
debug!("response not 2XX");
return Ok(None);
}
let text = response.text().await?;
trace!("response text: {text:?}");
if text.len() >= 12288 {
debug_warn!("response contains junk");
return Ok(None);
}
let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default();
let m_server = body
.get("m.server")
.unwrap_or(&serde_json::Value::Null)
.as_str()
.unwrap_or_default();
if ruma::identifiers_validation::server_name::validate(m_server).is_err() {
debug_error!("response content missing or invalid");
return Ok(None);
}
debug_info!("{dest:?} found at {m_server:?}");
Ok(Some(m_server.to_owned()))
}