From e7e606300f33410bfb6bfdf7c9671b210e37f287 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sun, 27 Oct 2024 19:17:41 +0000 Subject: [PATCH] slightly simplify reqwest/hickory hooks Signed-off-by: Jason Volk --- src/service/resolver/dns.rs | 59 ++++++++++++++++--------------------- 1 file changed, 26 insertions(+), 33 deletions(-) diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index b77bbb84..89129e03 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -1,15 +1,11 @@ -use std::{ - future, iter, - net::{IpAddr, SocketAddr}, - sync::Arc, - time::Duration, -}; +use std::{iter, net::SocketAddr, sync::Arc, time::Duration}; use conduit::{err, Result, Server}; +use futures::FutureExt; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use super::cache::Cache; +use super::cache::{Cache, CachedOverride}; pub struct Resolver { pub(crate) resolver: Arc, @@ -21,6 +17,8 @@ pub(crate) struct Hooked { cache: Arc, } +type ResolvingResult = Result>; + impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub(super) fn build(server: &Arc, cache: Arc) -> Result> { @@ -82,12 +80,12 @@ impl Resolver { } impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() } } impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let cached = self + let cached: Option = self .cache .overrides .read() @@ -95,35 +93,30 @@ impl Resolve for Hooked { .get(name.as_str()) .cloned(); - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } + cached.map_or_else( + || resolve_to_reqwest(self.resolver.clone(), name).boxed(), + |cached| cached_to_reqwest(cached).boxed(), + ) } } -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name +async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult { + let first_ip = cached + .ips .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") + .expect("must provide at least one override"); + + let saddr = SocketAddr::new(*first_ip, cached.port); + + Ok(Box::new(iter::once(saddr))) } -fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); +async fn resolve_to_reqwest(resolver: Arc, name: Name) -> ResolvingResult { + let results = resolver + .lookup_ip(name.as_str()) + .await? + .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); - let results: Addrs = Box::new(results); - - Ok(results) - }) + Ok(Box::new(results)) }