From 9c6b5b44070c16a72a5149b4a0f104d2ef6ba9dd Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Mon, 30 Dec 2024 13:56:21 +0000 Subject: [PATCH] add faster interruption to resolver (#649) Signed-off-by: Jason Volk --- src/service/resolver/dns.rs | 38 +++++++++++++++++++++++-------------- 1 file changed, 24 insertions(+), 14 deletions(-) diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index c331dfba..5c9018ab 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -2,7 +2,7 @@ use std::{net::SocketAddr, sync::Arc, time::Duration}; use conduwuit::{err, Result, Server}; use futures::FutureExt; -use hickory_resolver::TokioAsyncResolver; +use hickory_resolver::{lookup_ip::LookupIp, TokioAsyncResolver}; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use super::cache::{Cache, CachedOverride}; @@ -10,11 +10,13 @@ use super::cache::{Cache, CachedOverride}; pub struct Resolver { pub(crate) resolver: Arc, pub(crate) hooked: Arc, + server: Arc, } pub(crate) struct Hooked { resolver: Arc, cache: Arc, + server: Arc, } type ResolvingResult = Result>; @@ -72,14 +74,15 @@ impl Resolver { let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); Ok(Arc::new(Self { resolver: resolver.clone(), - hooked: Arc::new(Hooked { resolver, cache }), + hooked: Arc::new(Hooked { resolver, cache, server: server.clone() }), + server: server.clone(), })) } } impl Resolve for Resolver { fn resolve(&self, name: Name) -> Resolving { - resolve_to_reqwest(self.resolver.clone(), name).boxed() + resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed() } } @@ -94,12 +97,29 @@ impl Resolve for Hooked { .cloned(); cached.map_or_else( - || resolve_to_reqwest(self.resolver.clone(), name).boxed(), + || resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed(), |cached| cached_to_reqwest(cached).boxed(), ) } } +async fn resolve_to_reqwest( + server: Arc, + resolver: Arc, + name: Name, +) -> ResolvingResult { + use std::{io, io::ErrorKind::Interrupted}; + + let handle_shutdown = || Box::new(io::Error::new(Interrupted, "Server shutting down")); + let handle_results = + |results: LookupIp| Box::new(results.into_iter().map(|ip| SocketAddr::new(ip, 0))); + + tokio::select! { + results = resolver.lookup_ip(name.as_str()) => Ok(handle_results(results?)), + () = server.until_shutdown() => Err(handle_shutdown()), + } +} + async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { let addrs = cached .ips @@ -108,13 +128,3 @@ async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { Ok(Box::new(addrs)) } - -async fn resolve_to_reqwest(resolver: Arc, name: Name) -> ResolvingResult { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); - - Ok(Box::new(results)) -}