use std::{ collections::HashMap, future, iter, net::{IpAddr, SocketAddr}, sync::{Arc, RwLock}, time::Duration, }; use conduit::{error, Config, Error}; use hickory_resolver::TokioAsyncResolver; use reqwest::dns::{Addrs, Name, Resolve, Resolving}; use ruma::OwnedServerName; use crate::sending::FedDest; type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; pub struct Resolver { pub destinations: Arc>, // actual_destination, host pub overrides: Arc>, pub resolver: Arc, pub hooked: Arc, } pub struct Hooked { pub overrides: Arc>, pub resolver: Arc, } impl Resolver { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] pub fn new(config: &Config) -> Self { let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() .map_err(|e| { error!("Failed to set up hickory dns resolver with system config: {e}"); Error::bad_config("Failed to set up hickory dns resolver with system config.") }) .expect("DNS system config must be valid"); let mut conf = hickory_resolver::config::ResolverConfig::new(); if let Some(domain) = sys_conf.domain() { conf.set_domain(domain.clone()); } for sys_conf in sys_conf.search() { conf.add_search(sys_conf.clone()); } for sys_conf in sys_conf.name_servers() { let mut ns = sys_conf.clone(); if config.query_over_tcp_only { ns.protocol = hickory_resolver::config::Protocol::Tcp; } ns.trust_negative_responses = !config.query_all_nameservers; conf.add_name_server(ns); } opts.cache_size = config.dns_cache_entries as usize; opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); opts.timeout = Duration::from_secs(config.dns_timeout); opts.attempts = config.dns_attempts as usize; opts.try_tcp_on_error = config.dns_tcp_fallback; opts.num_concurrent_reqs = 1; opts.shuffle_dns_servers = true; opts.rotate = true; opts.ip_strategy = match config.ip_lookup_strategy { 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, }; opts.authentic_data = false; let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); let overrides = Arc::new(RwLock::new(TlsNameMap::new())); Self { destinations: Arc::new(RwLock::new(WellKnownMap::new())), overrides: overrides.clone(), resolver: resolver.clone(), hooked: Arc::new(Hooked { overrides, resolver, }), } } } impl Resolve for Resolver { fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } } impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { let addr_port = self .overrides .read() .expect("locked for reading") .get(name.as_str()) .cloned(); if let Some((addr, port)) = addr_port { cached_to_reqwest(&addr, port) } else { resolve_to_reqwest(self.resolver.clone(), name) } } } fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { override_name .first() .map(|first_name| -> Resolving { let saddr = SocketAddr::new(*first_name, port); let result: Box + Send> = Box::new(iter::once(saddr)); Box::pin(future::ready(Ok(result))) }) .expect("must provide at least one override name") } fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { Box::pin(async move { let results = resolver .lookup_ip(name.as_str()) .await? .into_iter() .map(|ip| SocketAddr::new(ip, 0)); let results: Addrs = Box::new(results); Ok(results) }) }