add faster interruption to resolver (#649)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-30 13:56:21 +00:00 committed by strawberry
parent a1fc4d49ac
commit 9c6b5b4407

View file

@ -2,7 +2,7 @@ use std::{net::SocketAddr, sync::Arc, time::Duration};
use conduwuit::{err, Result, Server}; use conduwuit::{err, Result, Server};
use futures::FutureExt; use futures::FutureExt;
use hickory_resolver::TokioAsyncResolver; use hickory_resolver::{lookup_ip::LookupIp, TokioAsyncResolver};
use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use super::cache::{Cache, CachedOverride}; use super::cache::{Cache, CachedOverride};
@ -10,11 +10,13 @@ use super::cache::{Cache, CachedOverride};
pub struct Resolver { pub struct Resolver {
pub(crate) resolver: Arc<TokioAsyncResolver>, pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>, pub(crate) hooked: Arc<Hooked>,
server: Arc<Server>,
} }
pub(crate) struct Hooked { pub(crate) struct Hooked {
resolver: Arc<TokioAsyncResolver>, resolver: Arc<TokioAsyncResolver>,
cache: Arc<Cache>, cache: Arc<Cache>,
server: Arc<Server>,
} }
type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>; type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
@ -72,14 +74,15 @@ impl Resolver {
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
Ok(Arc::new(Self { Ok(Arc::new(Self {
resolver: resolver.clone(), resolver: resolver.clone(),
hooked: Arc::new(Hooked { resolver, cache }), hooked: Arc::new(Hooked { resolver, cache, server: server.clone() }),
server: server.clone(),
})) }))
} }
} }
impl Resolve for Resolver { impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving { fn resolve(&self, name: Name) -> Resolving {
resolve_to_reqwest(self.resolver.clone(), name).boxed() resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed()
} }
} }
@ -94,12 +97,29 @@ impl Resolve for Hooked {
.cloned(); .cloned();
cached.map_or_else( cached.map_or_else(
|| resolve_to_reqwest(self.resolver.clone(), name).boxed(), || resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed(),
|cached| cached_to_reqwest(cached).boxed(), |cached| cached_to_reqwest(cached).boxed(),
) )
} }
} }
async fn resolve_to_reqwest(
server: Arc<Server>,
resolver: Arc<TokioAsyncResolver>,
name: Name,
) -> ResolvingResult {
use std::{io, io::ErrorKind::Interrupted};
let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
let handle_results =
|results: LookupIp| Box::new(results.into_iter().map(|ip| SocketAddr::new(ip, 0)));
tokio::select! {
results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),
() = server.until_shutdown() => Err(handle_shutdown()),
}
}
async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
let addrs = cached let addrs = cached
.ips .ips
@ -108,13 +128,3 @@ async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
Ok(Box::new(addrs)) Ok(Box::new(addrs))
} }
async fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> ResolvingResult {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
Ok(Box::new(results))
}