modernize async srv lookup

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-30 15:13:35 +00:00
parent 9a9c071e82
commit ed8c21ac9a
2 changed files with 16 additions and 32 deletions

View file

@ -1,11 +1,10 @@
use std::{ use std::{
fmt::Debug, fmt::Debug,
net::{IpAddr, SocketAddr}, net::{IpAddr, SocketAddr},
sync::Arc,
}; };
use conduit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result}; use conduit::{debug, debug_error, debug_info, debug_warn, err, error, trace, Err, Result};
use hickory_resolver::{error::ResolveError, lookup::SrvLookup}; use hickory_resolver::error::ResolveError;
use ipaddress::IPAddress; use ipaddress::IPAddress;
use ruma::ServerName; use ruma::ServerName;
@ -258,7 +257,7 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "ip")] #[tracing::instrument(skip_all, name = "ip")]
async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> {
match self.raw().lookup_ip(hostname.to_owned()).await { match self.resolver.resolver.lookup_ip(hostname.to_owned()).await {
Err(e) => Self::handle_resolve_error(&e), Err(e) => Self::handle_resolve_error(&e),
Ok(override_ip) => { Ok(override_ip) => {
if hostname != overname { if hostname != overname {
@ -281,32 +280,24 @@ impl super::Service {
#[tracing::instrument(skip_all, name = "srv")] #[tracing::instrument(skip_all, name = "srv")]
async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> { async fn query_srv_record(&self, hostname: &'_ str) -> Result<Option<FedDest>> {
fn handle_successful_srv(srv: &SrvLookup) -> Option<FedDest> {
srv.iter().next().map(|result| {
FedDest::Named(
result.target().to_string().trim_end_matches('.').to_owned(),
format!(":{}", result.port())
.as_str()
.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
)
})
}
async fn lookup_srv(
resolver: Arc<super::TokioAsyncResolver>, hostname: &str,
) -> Result<SrvLookup, ResolveError> {
debug!("querying SRV for {hostname:?}");
let hostname = hostname.trim_end_matches('.');
resolver.srv_lookup(hostname.to_owned()).await
}
let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")];
for hostname in hostnames { for hostname in hostnames {
match lookup_srv(self.raw(), &hostname).await { debug!("querying SRV for {hostname:?}");
Ok(result) => return Ok(handle_successful_srv(&result)), let hostname = hostname.trim_end_matches('.');
match self.resolver.resolver.srv_lookup(hostname).await {
Err(e) => Self::handle_resolve_error(&e)?, Err(e) => Self::handle_resolve_error(&e)?,
Ok(result) => {
return Ok(result.iter().next().map(|result| {
FedDest::Named(
result.target().to_string().trim_end_matches('.').to_owned(),
format!(":{}", result.port())
.as_str()
.try_into()
.unwrap_or_else(|_| FedDest::default_port()),
)
}))
},
} }
} }

View file

@ -7,7 +7,6 @@ mod tests;
use std::{fmt::Write, sync::Arc}; use std::{fmt::Write, sync::Arc};
use conduit::{Result, Server}; use conduit::{Result, Server};
use hickory_resolver::TokioAsyncResolver;
use self::{cache::Cache, dns::Resolver}; use self::{cache::Cache, dns::Resolver};
use crate::{client, globals, Dep}; use crate::{client, globals, Dep};
@ -71,9 +70,3 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
} }
impl Service {
#[inline]
#[must_use]
pub fn raw(&self) -> Arc<TokioAsyncResolver> { self.resolver.resolver.clone() }
}