add faster interruption to resolver (#649)

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-12-30 13:56:21 +00:00 committed by strawberry
parent a1fc4d49ac
commit 9c6b5b4407

View file

@ -2,7 +2,7 @@ use std::{net::SocketAddr, sync::Arc, time::Duration};
use conduwuit::{err, Result, Server};
use futures::FutureExt;
use hickory_resolver::TokioAsyncResolver;
use hickory_resolver::{lookup_ip::LookupIp, TokioAsyncResolver};
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use super::cache::{Cache, CachedOverride};
@ -10,11 +10,13 @@ use super::cache::{Cache, CachedOverride};
pub struct Resolver {
pub(crate) resolver: Arc<TokioAsyncResolver>,
pub(crate) hooked: Arc<Hooked>,
server: Arc<Server>,
}
pub(crate) struct Hooked {
resolver: Arc<TokioAsyncResolver>,
cache: Arc<Cache>,
server: Arc<Server>,
}
type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
@ -72,14 +74,15 @@ impl Resolver {
let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts));
Ok(Arc::new(Self {
resolver: resolver.clone(),
hooked: Arc::new(Hooked { resolver, cache }),
hooked: Arc::new(Hooked { resolver, cache, server: server.clone() }),
server: server.clone(),
}))
}
}
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving {
resolve_to_reqwest(self.resolver.clone(), name).boxed()
resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed()
}
}
@ -94,12 +97,29 @@ impl Resolve for Hooked {
.cloned();
cached.map_or_else(
|| resolve_to_reqwest(self.resolver.clone(), name).boxed(),
|| resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed(),
|cached| cached_to_reqwest(cached).boxed(),
)
}
}
async fn resolve_to_reqwest(
server: Arc<Server>,
resolver: Arc<TokioAsyncResolver>,
name: Name,
) -> ResolvingResult {
use std::{io, io::ErrorKind::Interrupted};
let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down"));
let handle_results =
|results: LookupIp| Box::new(results.into_iter().map(|ip| SocketAddr::new(ip, 0)));
tokio::select! {
results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)),
() = server.until_shutdown() => Err(handle_shutdown()),
}
}
async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
let addrs = cached
.ips
@ -108,13 +128,3 @@ async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
Ok(Box::new(addrs))
}
async fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> ResolvingResult {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
Ok(Box::new(results))
}