slightly simplify reqwest/hickory hooks

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-10-27 19:17:41 +00:00
parent 9787dfe77c
commit e7e606300f

View file

@ -1,15 +1,11 @@
use std::{
future, iter,
net::{IpAddr, SocketAddr},
sync::Arc,
time::Duration,
};
use std::{iter, net::SocketAddr, sync::Arc, time::Duration};
use conduit::{err, Result, Server};
use futures::FutureExt;
use hickory_resolver::TokioAsyncResolver;
use reqwest::dns::{Addrs, Name, Resolve, Resolving};
use super::cache::Cache;
use super::cache::{Cache, CachedOverride};
pub struct Resolver {
pub(crate) resolver: Arc<TokioAsyncResolver>,
@ -21,6 +17,8 @@ pub(crate) struct Hooked {
cache: Arc<Cache>,
}
type ResolvingResult = Result<Addrs, Box<dyn std::error::Error + Send + Sync>>;
impl Resolver {
#[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)]
pub(super) fn build(server: &Arc<Server>, cache: Arc<Cache>) -> Result<Arc<Self>> {
@ -82,12 +80,12 @@ impl Resolver {
}
impl Resolve for Resolver {
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) }
fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name).boxed() }
}
impl Resolve for Hooked {
fn resolve(&self, name: Name) -> Resolving {
let cached = self
let cached: Option<CachedOverride> = self
.cache
.overrides
.read()
@ -95,35 +93,30 @@ impl Resolve for Hooked {
.get(name.as_str())
.cloned();
if let Some(cached) = cached {
cached_to_reqwest(&cached.ips, cached.port)
} else {
resolve_to_reqwest(self.resolver.clone(), name)
}
cached.map_or_else(
|| resolve_to_reqwest(self.resolver.clone(), name).boxed(),
|cached| cached_to_reqwest(cached).boxed(),
)
}
}
fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving {
override_name
async fn cached_to_reqwest(cached: CachedOverride) -> ResolvingResult {
let first_ip = cached
.ips
.first()
.map(|first_name| -> Resolving {
let saddr = SocketAddr::new(*first_name, port);
let result: Box<dyn Iterator<Item = SocketAddr> + Send> = Box::new(iter::once(saddr));
Box::pin(future::ready(Ok(result)))
})
.expect("must provide at least one override name")
.expect("must provide at least one override");
let saddr = SocketAddr::new(*first_ip, cached.port);
Ok(Box::new(iter::once(saddr)))
}
fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> Resolving {
Box::pin(async move {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
async fn resolve_to_reqwest(resolver: Arc<TokioAsyncResolver>, name: Name) -> ResolvingResult {
let results = resolver
.lookup_ip(name.as_str())
.await?
.into_iter()
.map(|ip| SocketAddr::new(ip, 0));
let results: Addrs = Box::new(results);
Ok(results)
})
Ok(Box::new(results))
}