From 3ccd9ea326cefaf80946b4c334f314d3c87b0598 Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Tue, 16 Jul 2024 23:38:48 +0000 Subject: [PATCH] consolidate all resolution in resolver; split units Signed-off-by: Jason Volk --- src/admin/debug/commands.rs | 7 +- src/admin/query/resolver.rs | 13 +- src/service/client/mod.rs | 14 +- src/service/resolver/actual.rs | 356 +++++++++++++++++++++++++++++++++ src/service/resolver/cache.rs | 114 +++++++++++ src/service/resolver/dns.rs | 129 ++++++++++++ src/service/resolver/fed.rs | 70 +++++++ src/service/resolver/mod.rs | 351 ++++---------------------------- src/service/resolver/tests.rs | 43 ++++ src/service/sending/mod.rs | 2 - src/service/sending/send.rs | 11 +- 11 files changed, 774 insertions(+), 336 deletions(-) create mode 100644 src/service/resolver/actual.rs create mode 100644 src/service/resolver/cache.rs create mode 100644 src/service/resolver/dns.rs create mode 100644 src/service/resolver/fed.rs create mode 100644 src/service/resolver/tests.rs diff --git a/src/admin/debug/commands.rs b/src/admin/debug/commands.rs index 7c7f9331..f319f5a5 100644 --- a/src/admin/debug/commands.rs +++ b/src/admin/debug/commands.rs @@ -15,7 +15,7 @@ use ruma::{ events::room::message::RoomMessageEventContent, CanonicalJsonObject, EventId, OwnedRoomOrAliasId, RoomId, RoomVersionId, ServerName, }; -use service::{rooms::event_handler::parse_incoming_pdu, sending::resolve_actual_dest, services, PduEvent}; +use service::{rooms::event_handler::parse_incoming_pdu, services, PduEvent}; use tokio::sync::RwLock; use tracing_subscriber::EnvFilter; @@ -628,7 +628,10 @@ pub(super) async fn resolve_true_destination( let capture = Capture::new(state, Some(filter), capture::fmt_markdown(logs.clone())); let capture_scope = capture.start(); - let actual = resolve_actual_dest(&server_name, !no_cache).await?; + let actual = services() + .resolver + .resolve_actual_dest(&server_name, !no_cache) + .await?; drop(capture_scope); let msg = format!( diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index 06cd8ba9..37d17960 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -19,7 +19,7 @@ pub(super) async fn resolver(subcommand: Resolver) -> Result) -> Result { - use service::resolver::CachedDest; + use service::resolver::cache::CachedDest; let mut out = String::new(); writeln!(out, "| Server Name | Destination | Hostname | Expires |")?; @@ -36,7 +36,12 @@ async fn destinations_cache(server_name: Option) -> Result) -> Result) -> Result { - use service::resolver::CachedOverride; + use service::resolver::cache::CachedOverride; let mut out = String::new(); writeln!(out, "| Server Name | IP | Port | Expires |")?; @@ -65,7 +70,7 @@ async fn overrides_cache(server_name: Option) -> Result Result { + let cached; + let cached_result = self.get_cached_destination(server_name); + + let CachedDest { + dest, + host, + .. + } = if let Some(result) = cached_result { + cached = true; + result + } else { + cached = false; + validate_dest(server_name)?; + self.resolve_actual_dest(server_name, true).await? + }; + + let string = dest.clone().into_https_string(); + Ok(ActualDest { + dest, + host, + string, + cached, + }) + } + + /// Returns: `actual_destination`, host header + /// Implemented according to the specification at + /// Numbers in comments below refer to bullet points in linked section of + /// specification + #[tracing::instrument(skip_all, name = "actual")] + pub async fn resolve_actual_dest(&self, dest: &ServerName, cache: bool) -> Result { + trace!("Finding actual destination for {dest}"); + let mut host = dest.as_str().to_owned(); + let actual_dest = match get_ip_with_port(dest.as_str()) { + Some(host_port) => Self::actual_dest_1(host_port)?, + None => { + if let Some(pos) = dest.as_str().find(':') { + self.actual_dest_2(dest, cache, pos).await? + } else if let Some(delegated) = self.request_well_known(dest.as_str()).await? { + self.actual_dest_3(&mut host, cache, delegated).await? + } else if let Some(overrider) = self.query_srv_record(dest.as_str()).await? { + self.actual_dest_4(&host, cache, overrider).await? + } else { + self.actual_dest_5(dest, cache).await? + } + }, + }; + + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let host = if let Ok(addr) = host.parse::() { + FedDest::Literal(addr) + } else if let Ok(addr) = host.parse::() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = host.find(':') { + let (host, port) = host.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + FedDest::Named(host, ":8448".to_owned()) + }; + + debug!("Actual destination: {actual_dest:?} hostname: {host:?}"); + Ok(CachedDest { + dest: actual_dest, + host: host.into_uri_string(), + expire: CachedDest::default_expire(), + }) + } + + fn actual_dest_1(host_port: FedDest) -> Result { + debug!("1: IP literal with provided or default port"); + Ok(host_port) + } + + async fn actual_dest_2(&self, dest: &ServerName, cache: bool, pos: usize) -> Result { + debug!("2: Hostname with included port"); + let (host, port) = dest.as_str().split_at(pos); + self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) + .await?; + Ok(FedDest::Named(host.to_owned(), port.to_owned())) + } + + async fn actual_dest_3(&self, host: &mut String, cache: bool, delegated: String) -> Result { + debug!("3: A .well-known file is available"); + *host = add_port_to_hostname(&delegated).into_uri_string(); + match get_ip_with_port(&delegated) { + Some(host_and_port) => Self::actual_dest_3_1(host_and_port), + None => { + if let Some(pos) = delegated.find(':') { + self.actual_dest_3_2(cache, delegated, pos).await + } else { + trace!("Delegated hostname has no port in this branch"); + if let Some(overrider) = self.query_srv_record(&delegated).await? { + self.actual_dest_3_3(cache, delegated, overrider).await + } else { + self.actual_dest_3_4(cache, delegated).await + } + } + }, + } + } + + fn actual_dest_3_1(host_and_port: FedDest) -> Result { + debug!("3.1: IP literal in .well-known file"); + Ok(host_and_port) + } + + async fn actual_dest_3_2(&self, cache: bool, delegated: String, pos: usize) -> Result { + debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated.split_at(pos); + self.conditional_query_and_cache_override(host, host, port.parse::().unwrap_or(8448), cache) + .await?; + Ok(FedDest::Named(host.to_owned(), port.to_owned())) + } + + async fn actual_dest_3_3(&self, cache: bool, delegated: String, overrider: FedDest) -> Result { + debug!("3.3: SRV lookup successful"); + let force_port = overrider.port(); + self.conditional_query_and_cache_override(&delegated, &overrider.hostname(), force_port.unwrap_or(8448), cache) + .await?; + if let Some(port) = force_port { + Ok(FedDest::Named(delegated, format!(":{port}"))) + } else { + Ok(add_port_to_hostname(&delegated)) + } + } + + async fn actual_dest_3_4(&self, cache: bool, delegated: String) -> Result { + debug!("3.4: No SRV records, just use the hostname from .well-known"); + self.conditional_query_and_cache_override(&delegated, &delegated, 8448, cache) + .await?; + Ok(add_port_to_hostname(&delegated)) + } + + async fn actual_dest_4(&self, host: &str, cache: bool, overrider: FedDest) -> Result { + debug!("4: No .well-known; SRV record found"); + let force_port = overrider.port(); + self.conditional_query_and_cache_override(host, &overrider.hostname(), force_port.unwrap_or(8448), cache) + .await?; + if let Some(port) = force_port { + Ok(FedDest::Named(host.to_owned(), format!(":{port}"))) + } else { + Ok(add_port_to_hostname(host)) + } + } + + async fn actual_dest_5(&self, dest: &ServerName, cache: bool) -> Result { + debug!("5: No SRV record found"); + self.conditional_query_and_cache_override(dest.as_str(), dest.as_str(), 8448, cache) + .await?; + Ok(add_port_to_hostname(dest.as_str())) + } + + #[tracing::instrument(skip_all, name = "well-known")] + async fn request_well_known(&self, dest: &str) -> Result> { + trace!("Requesting well known for {dest}"); + if !self.has_cached_override(dest) { + self.query_and_cache_override(dest, dest, 8448).await?; + } + + let response = services() + .client + .well_known + .get(&format!("https://{dest}/.well-known/matrix/server")) + .send() + .await; + + trace!("response: {:?}", response); + if let Err(e) = &response { + debug!("error: {e:?}"); + return Ok(None); + } + + let response = response?; + if !response.status().is_success() { + debug!("response not 2XX"); + return Ok(None); + } + + let text = response.text().await?; + trace!("response text: {:?}", text); + if text.len() >= 12288 { + debug_warn!("response contains junk"); + return Ok(None); + } + + let body: serde_json::Value = serde_json::from_str(&text).unwrap_or_default(); + + let m_server = body + .get("m.server") + .unwrap_or(&serde_json::Value::Null) + .as_str() + .unwrap_or_default(); + + if ruma::identifiers_validation::server_name::validate(m_server).is_err() { + debug_error!("response content missing or invalid"); + return Ok(None); + } + + debug_info!("{:?} found at {:?}", dest, m_server); + Ok(Some(m_server.to_owned())) + } + + #[inline] + async fn conditional_query_and_cache_override( + &self, overname: &str, hostname: &str, port: u16, cache: bool, + ) -> Result<()> { + if cache { + self.query_and_cache_override(overname, hostname, port) + .await + } else { + Ok(()) + } + } + + #[tracing::instrument(skip_all, name = "ip")] + async fn query_and_cache_override(&self, overname: &'_ str, hostname: &'_ str, port: u16) -> Result<()> { + match services() + .resolver + .raw() + .lookup_ip(hostname.to_owned()) + .await + { + Err(e) => handle_resolve_error(&e), + Ok(override_ip) => { + if hostname != overname { + debug_info!("{overname:?} overriden by {hostname:?}"); + } + + services().resolver.set_cached_override( + overname.to_owned(), + CachedOverride { + ips: override_ip.iter().collect(), + port, + expire: CachedOverride::default_expire(), + }, + ); + + Ok(()) + }, + } + } + + #[tracing::instrument(skip_all, name = "srv")] + async fn query_srv_record(&self, hostname: &'_ str) -> Result> { + fn handle_successful_srv(srv: &SrvLookup) -> Option { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + } + + async fn lookup_srv( + resolver: Arc, hostname: &str, + ) -> Result { + debug!("querying SRV for {hostname:?}"); + let hostname = hostname.trim_end_matches('.'); + resolver.srv_lookup(hostname.to_owned()).await + } + + let hostnames = [format!("_matrix-fed._tcp.{hostname}."), format!("_matrix._tcp.{hostname}.")]; + + for hostname in hostnames { + match lookup_srv(self.raw(), &hostname).await { + Ok(result) => return Ok(handle_successful_srv(&result)), + Err(e) => handle_resolve_error(&e)?, + } + } + + Ok(None) + } +} + +#[allow(clippy::single_match_else)] +fn handle_resolve_error(e: &ResolveError) -> Result<()> { + use hickory_resolver::error::ResolveErrorKind; + + match *e.kind() { + ResolveErrorKind::NoRecordsFound { + .. + } => { + // Raise to debug_warn if we can find out the result wasn't from cache + debug!("{e}"); + Ok(()) + }, + _ => Err!(error!("DNS {e}")), + } +} + +fn validate_dest(dest: &ServerName) -> Result<()> { + if dest == services().globals.server_name() { + return Err!("Won't send federation request to ourselves"); + } + + if dest.is_ip_literal() || IPAddress::is_valid(dest.host()) { + validate_dest_ip_literal(dest)?; + } + + Ok(()) +} + +fn validate_dest_ip_literal(dest: &ServerName) -> Result<()> { + trace!("Destination is an IP literal, checking against IP range denylist.",); + debug_assert!( + dest.is_ip_literal() || !IPAddress::is_valid(dest.host()), + "Destination is not an IP literal." + ); + let ip = IPAddress::parse(dest.host()).map_err(|e| { + debug_error!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + validate_ip(&ip)?; + + Ok(()) +} + +pub(crate) fn validate_ip(ip: &IPAddress) -> Result<()> { + if !services().globals.valid_cidr_range(ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + + Ok(()) +} diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs new file mode 100644 index 00000000..0fba2400 --- /dev/null +++ b/src/service/resolver/cache.rs @@ -0,0 +1,114 @@ +use std::{ + collections::HashMap, + net::IpAddr, + sync::{Arc, RwLock}, + time::SystemTime, +}; + +use conduit::trace; +use ruma::{OwnedServerName, ServerName}; + +use super::fed::FedDest; +use crate::utils::rand; + +pub struct Cache { + pub destinations: RwLock, // actual_destination, host + pub overrides: RwLock, +} + +#[derive(Clone, Debug)] +pub struct CachedDest { + pub dest: FedDest, + pub host: String, + pub expire: SystemTime, +} + +#[derive(Clone, Debug)] +pub struct CachedOverride { + pub ips: Vec, + pub port: u16, + pub expire: SystemTime, +} + +pub type WellKnownMap = HashMap; +pub type TlsNameMap = HashMap; + +impl Cache { + pub(super) fn new() -> Arc { + Arc::new(Self { + destinations: RwLock::new(WellKnownMap::new()), + overrides: RwLock::new(TlsNameMap::new()), + }) + } +} + +impl super::Service { + pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { + trace!(?name, ?dest, "set cached destination"); + self.cache + .destinations + .write() + .expect("locked for writing") + .insert(name, dest) + } + + #[must_use] + pub fn get_cached_destination(&self, name: &ServerName) -> Option { + self.cache + .destinations + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { + trace!(?name, ?over, "set cached override"); + self.cache + .overrides + .write() + .expect("locked for writing") + .insert(name, over) + } + + #[must_use] + pub fn get_cached_override(&self, name: &str) -> Option { + self.cache + .overrides + .read() + .expect("locked for reading") + .get(name) + .cloned() + } + + #[must_use] + pub fn has_cached_override(&self, name: &str) -> bool { + self.cache + .overrides + .read() + .expect("locked for reading") + .contains_key(name) + } +} + +impl CachedDest { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } +} + +impl CachedOverride { + #[inline] + #[must_use] + pub fn valid(&self) -> bool { true } + + //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } + + #[must_use] + pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } +} diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs new file mode 100644 index 00000000..b77bbb84 --- /dev/null +++ b/src/service/resolver/dns.rs @@ -0,0 +1,129 @@ +use std::{ + future, iter, + net::{IpAddr, SocketAddr}, + sync::Arc, + time::Duration, +}; + +use conduit::{err, Result, Server}; +use hickory_resolver::TokioAsyncResolver; +use reqwest::dns::{Addrs, Name, Resolve, Resolving}; + +use super::cache::Cache; + +pub struct Resolver { + pub(crate) resolver: Arc, + pub(crate) hooked: Arc, +} + +pub(crate) struct Hooked { + resolver: Arc, + cache: Arc, +} + +impl Resolver { + #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] + pub(super) fn build(server: &Arc, cache: Arc) -> Result> { + let config = &server.config; + let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() + .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; + + let mut conf = hickory_resolver::config::ResolverConfig::new(); + + if let Some(domain) = sys_conf.domain() { + conf.set_domain(domain.clone()); + } + + for sys_conf in sys_conf.search() { + conf.add_search(sys_conf.clone()); + } + + for sys_conf in sys_conf.name_servers() { + let mut ns = sys_conf.clone(); + + if config.query_over_tcp_only { + ns.protocol = hickory_resolver::config::Protocol::Tcp; + } + + ns.trust_negative_responses = !config.query_all_nameservers; + + conf.add_name_server(ns); + } + + opts.cache_size = config.dns_cache_entries as usize; + opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); + opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); + opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); + opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); + opts.timeout = Duration::from_secs(config.dns_timeout); + opts.attempts = config.dns_attempts as usize; + opts.try_tcp_on_error = config.dns_tcp_fallback; + opts.num_concurrent_reqs = 1; + opts.shuffle_dns_servers = true; + opts.rotate = true; + opts.ip_strategy = match config.ip_lookup_strategy { + 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, + 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, + 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, + 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, + _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, + }; + opts.authentic_data = false; + + let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); + Ok(Arc::new(Self { + resolver: resolver.clone(), + hooked: Arc::new(Hooked { + resolver, + cache, + }), + })) + } +} + +impl Resolve for Resolver { + fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } +} + +impl Resolve for Hooked { + fn resolve(&self, name: Name) -> Resolving { + let cached = self + .cache + .overrides + .read() + .expect("locked for reading") + .get(name.as_str()) + .cloned(); + + if let Some(cached) = cached { + cached_to_reqwest(&cached.ips, cached.port) + } else { + resolve_to_reqwest(self.resolver.clone(), name) + } + } +} + +fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { + override_name + .first() + .map(|first_name| -> Resolving { + let saddr = SocketAddr::new(*first_name, port); + let result: Box + Send> = Box::new(iter::once(saddr)); + Box::pin(future::ready(Ok(result))) + }) + .expect("must provide at least one override name") +} + +fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { + Box::pin(async move { + let results = resolver + .lookup_ip(name.as_str()) + .await? + .into_iter() + .map(|ip| SocketAddr::new(ip, 0)); + + let results: Addrs = Box::new(results); + + Ok(results) + }) +} diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs new file mode 100644 index 00000000..10cbbbdd --- /dev/null +++ b/src/service/resolver/fed.rs @@ -0,0 +1,70 @@ +use std::{ + fmt, + net::{IpAddr, SocketAddr}, +}; + +#[derive(Clone, Debug, PartialEq, Eq)] +pub enum FedDest { + Literal(SocketAddr), + Named(String, String), +} + +pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { + if let Ok(dest) = dest_str.parse::() { + Some(FedDest::Literal(dest)) + } else if let Ok(ip_addr) = dest_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } +} + +pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { + let (host, port) = match dest_str.find(':') { + None => (dest_str, ":8448"), + Some(pos) => dest_str.split_at(pos), + }; + + FedDest::Named(host.to_owned(), port.to_owned()) +} + +impl FedDest { + pub(crate) fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } + + pub(crate) fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => format!("{host}{port}"), + } + } + + pub(crate) fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } + + #[inline] + #[allow(clippy::string_slice)] + pub(crate) fn port(&self) -> Option { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } +} + +impl fmt::Display for FedDest { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + Self::Named(host, port) => write!(f, "{host}{port}"), + Self::Literal(addr) => write!(f, "{addr}"), + } + } +} diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 62fd1625..48ff8813 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -1,349 +1,66 @@ -use std::{ - collections::HashMap, - fmt, - fmt::Write, - future, iter, - net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, - time::{Duration, SystemTime}, -}; +pub mod actual; +pub mod cache; +mod dns; +pub mod fed; +mod tests; -use conduit::{err, trace, Result}; +use std::{fmt::Write, sync::Arc}; + +use conduit::Result; use hickory_resolver::TokioAsyncResolver; -use reqwest::dns::{Addrs, Name, Resolve, Resolving}; -use ruma::{OwnedServerName, ServerName}; -use crate::utils::rand; +use self::{cache::Cache, dns::Resolver}; pub struct Service { - pub destinations: Arc>, // actual_destination, host - pub overrides: Arc>, - pub(crate) resolver: Arc, - pub(crate) hooked: Arc, + pub cache: Arc, + pub resolver: Arc, } -pub(crate) struct Hooked { - overrides: Arc>, - resolver: Arc, -} - -#[derive(Clone, Debug)] -pub struct CachedDest { - pub dest: FedDest, - pub host: String, - pub expire: SystemTime, -} - -#[derive(Clone, Debug)] -pub struct CachedOverride { - pub ips: Vec, - pub port: u16, - pub expire: SystemTime, -} - -#[derive(Clone, Debug, PartialEq, Eq)] -pub enum FedDest { - Literal(SocketAddr), - Named(String, String), -} - -type WellKnownMap = HashMap; -type TlsNameMap = HashMap; - impl crate::Service for Service { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn build(args: crate::Args<'_>) -> Result> { - let config = &args.server.config; - let (sys_conf, mut opts) = hickory_resolver::system_conf::read_system_conf() - .map_err(|e| err!(error!("Failed to configure DNS resolver from system: {e}")))?; - - let mut conf = hickory_resolver::config::ResolverConfig::new(); - - if let Some(domain) = sys_conf.domain() { - conf.set_domain(domain.clone()); - } - - for sys_conf in sys_conf.search() { - conf.add_search(sys_conf.clone()); - } - - for sys_conf in sys_conf.name_servers() { - let mut ns = sys_conf.clone(); - - if config.query_over_tcp_only { - ns.protocol = hickory_resolver::config::Protocol::Tcp; - } - - ns.trust_negative_responses = !config.query_all_nameservers; - - conf.add_name_server(ns); - } - - opts.cache_size = config.dns_cache_entries as usize; - opts.negative_min_ttl = Some(Duration::from_secs(config.dns_min_ttl_nxdomain)); - opts.negative_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 30)); - opts.positive_min_ttl = Some(Duration::from_secs(config.dns_min_ttl)); - opts.positive_max_ttl = Some(Duration::from_secs(60 * 60 * 24 * 7)); - opts.timeout = Duration::from_secs(config.dns_timeout); - opts.attempts = config.dns_attempts as usize; - opts.try_tcp_on_error = config.dns_tcp_fallback; - opts.num_concurrent_reqs = 1; - opts.shuffle_dns_servers = true; - opts.rotate = true; - opts.ip_strategy = match config.ip_lookup_strategy { - 1 => hickory_resolver::config::LookupIpStrategy::Ipv4Only, - 2 => hickory_resolver::config::LookupIpStrategy::Ipv6Only, - 3 => hickory_resolver::config::LookupIpStrategy::Ipv4AndIpv6, - 4 => hickory_resolver::config::LookupIpStrategy::Ipv6thenIpv4, - _ => hickory_resolver::config::LookupIpStrategy::Ipv4thenIpv6, - }; - opts.authentic_data = false; - - let resolver = Arc::new(TokioAsyncResolver::tokio(conf, opts)); - let overrides = Arc::new(RwLock::new(TlsNameMap::new())); + let cache = Cache::new(); Ok(Arc::new(Self { - destinations: Arc::new(RwLock::new(WellKnownMap::new())), - overrides: overrides.clone(), - resolver: resolver.clone(), - hooked: Arc::new(Hooked { - overrides, - resolver, - }), + cache: cache.clone(), + resolver: Resolver::build(args.server, cache)?, })) } fn memory_usage(&self, out: &mut dyn Write) -> Result<()> { - let resolver_overrides_cache = self.overrides.read().expect("locked for reading").len(); + let resolver_overrides_cache = self + .cache + .overrides + .read() + .expect("locked for reading") + .len(); writeln!(out, "resolver_overrides_cache: {resolver_overrides_cache}")?; - let resolver_destinations_cache = self.destinations.read().expect("locked for reading").len(); + let resolver_destinations_cache = self + .cache + .destinations + .read() + .expect("locked for reading") + .len(); writeln!(out, "resolver_destinations_cache: {resolver_destinations_cache}")?; Ok(()) } fn clear_cache(&self) { - self.overrides.write().expect("write locked").clear(); - self.destinations.write().expect("write locked").clear(); - self.resolver.clear_cache(); + self.cache.overrides.write().expect("write locked").clear(); + self.cache + .destinations + .write() + .expect("write locked") + .clear(); + self.resolver.resolver.clear_cache(); } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } impl Service { - pub fn set_cached_destination(&self, name: OwnedServerName, dest: CachedDest) -> Option { - trace!(?name, ?dest, "set cached destination"); - self.destinations - .write() - .expect("locked for writing") - .insert(name, dest) - } - - #[must_use] - pub fn get_cached_destination(&self, name: &ServerName) -> Option { - self.destinations - .read() - .expect("locked for reading") - .get(name) - .cloned() - } - - pub fn set_cached_override(&self, name: String, over: CachedOverride) -> Option { - trace!(?name, ?over, "set cached override"); - self.overrides - .write() - .expect("locked for writing") - .insert(name, over) - } - - #[must_use] - pub fn get_cached_override(&self, name: &str) -> Option { - self.overrides - .read() - .expect("locked for reading") - .get(name) - .cloned() - } - - #[must_use] - pub fn has_cached_override(&self, name: &str) -> bool { - self.overrides - .read() - .expect("locked for reading") - .contains_key(name) - } -} - -impl Resolve for Service { - fn resolve(&self, name: Name) -> Resolving { resolve_to_reqwest(self.resolver.clone(), name) } -} - -impl Resolve for Hooked { - fn resolve(&self, name: Name) -> Resolving { - let cached = self - .overrides - .read() - .expect("locked for reading") - .get(name.as_str()) - .cloned(); - - if let Some(cached) = cached { - cached_to_reqwest(&cached.ips, cached.port) - } else { - resolve_to_reqwest(self.resolver.clone(), name) - } - } -} - -fn cached_to_reqwest(override_name: &[IpAddr], port: u16) -> Resolving { - override_name - .first() - .map(|first_name| -> Resolving { - let saddr = SocketAddr::new(*first_name, port); - let result: Box + Send> = Box::new(iter::once(saddr)); - Box::pin(future::ready(Ok(result))) - }) - .expect("must provide at least one override name") -} - -fn resolve_to_reqwest(resolver: Arc, name: Name) -> Resolving { - Box::pin(async move { - let results = resolver - .lookup_ip(name.as_str()) - .await? - .into_iter() - .map(|ip| SocketAddr::new(ip, 0)); - - let results: Addrs = Box::new(results); - - Ok(results) - }) -} - -impl CachedDest { #[inline] #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 18..60 * 60 * 36) } -} - -impl CachedOverride { - #[inline] - #[must_use] - pub fn valid(&self) -> bool { true } - - //pub fn valid(&self) -> bool { self.expire > SystemTime::now() } - - #[must_use] - pub(crate) fn default_expire() -> SystemTime { rand::timepoint_secs(60 * 60 * 6..60 * 60 * 12) } -} - -pub(crate) fn get_ip_with_port(dest_str: &str) -> Option { - if let Ok(dest) = dest_str.parse::() { - Some(FedDest::Literal(dest)) - } else if let Ok(ip_addr) = dest_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } -} - -pub(crate) fn add_port_to_hostname(dest_str: &str) -> FedDest { - let (host, port) = match dest_str.find(':') { - None => (dest_str, ":8448"), - Some(pos) => dest_str.split_at(pos), - }; - - FedDest::Named(host.to_owned(), port.to_owned()) -} - -impl FedDest { - pub(crate) fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } - - pub(crate) fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => format!("{host}{port}"), - } - } - - pub(crate) fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } - - #[inline] - #[allow(clippy::string_slice)] - pub(crate) fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } -} - -impl fmt::Display for FedDest { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - match self { - Self::Named(host, port) => write!(f, "{host}{port}"), - Self::Literal(addr) => write!(f, "{addr}"), - } - } -} - -#[cfg(test)] -mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } - - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } - - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ); - } - - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ); - } + pub fn raw(&self) -> Arc { self.resolver.resolver.clone() } } diff --git a/src/service/resolver/tests.rs b/src/service/resolver/tests.rs new file mode 100644 index 00000000..55cf0345 --- /dev/null +++ b/src/service/resolver/tests.rs @@ -0,0 +1,43 @@ +#![cfg(test)] + +use super::fed::{add_port_to_hostname, get_ip_with_port, FedDest}; + +#[test] +fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); +} + +#[test] +fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); +} + +#[test] +fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ); +} + +#[test] +fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ); +} diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index be10184d..1eacca77 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,13 +1,11 @@ mod appservice; mod data; -mod resolve; mod send; mod sender; use std::fmt::Debug; use conduit::{err, Result}; -pub use resolve::resolve_actual_dest; use ruma::{ api::{appservice::Registration, OutgoingRequest}, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index dc4541d6..df3139c3 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -15,8 +15,11 @@ use ruma::{ }; use tracing::{debug, trace}; -use super::{resolve, resolve::ActualDest}; -use crate::{debug_error, debug_warn, resolver::CachedDest, services, Error, Result}; +use crate::{ + debug_error, debug_warn, resolver, + resolver::{actual::ActualDest, cache::CachedDest}, + services, Error, Result, +}; #[tracing::instrument(skip_all, name = "send")] pub async fn send(client: &Client, dest: &ServerName, req: T) -> Result @@ -27,7 +30,7 @@ where return Err!(Config("allow_federation", "Federation is disabled.")); } - let actual = resolve::get_actual_dest(dest).await?; + let actual = services().resolver.get_actual_dest(dest).await?; let request = prepare::(dest, &actual, req).await?; execute::(client, dest, &actual, request).await } @@ -219,7 +222,7 @@ fn validate_url(url: &Url) -> Result<()> { if let Some(url_host) = url.host_str() { if let Ok(ip) = IPAddress::parse(url_host) { trace!("Checking request URL IP {ip:?}"); - resolve::validate_ip(&ip)?; + resolver::actual::validate_ip(&ip)?; } }