diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 00eaaa5f..2fe5d538 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -159,7 +159,7 @@ where let mut write_destination_to_cache = false; - let cached_result = services().globals.actual_destination_cache.read().await.get(destination).cloned(); + let cached_result = services().globals.actual_destinations().read().await.get(destination).cloned(); let (actual_destination, host) = if let Some(result) = cached_result { result @@ -313,7 +313,7 @@ where if response.is_ok() && write_destination_to_cache { services() .globals - .actual_destination_cache + .actual_destinations() .write() .await .insert(OwnedServerName::from(destination), (actual_destination, host)); @@ -496,7 +496,8 @@ async fn query_and_cache_override(overname: &'_ str, hostname: &'_ str, port: u1 services() .globals - .tls_name_override + .resolver + .overrides .write() .unwrap() .insert(overname.to_owned(), (override_ip.iter().collect(), port)); @@ -538,7 +539,7 @@ async fn query_srv_record(hostname: &'_ str) -> Option { } async fn request_well_known(destination: &str) -> Option { - if !services().globals.tls_name_override.read().unwrap().contains_key(destination) { + if !services().globals.resolver.overrides.read().unwrap().contains_key(destination) { query_and_cache_override(destination, destination, 8448).await; } diff --git a/src/service/globals/client.rs b/src/service/globals/client.rs new file mode 100644 index 00000000..0d05efa1 --- /dev/null +++ b/src/service/globals/client.rs @@ -0,0 +1,90 @@ +use std::{sync::Arc, time::Duration}; + +use reqwest::redirect; + +use crate::{service::globals::resolver, Config, Result}; + +pub struct Client { + pub default: reqwest::Client, + pub url_preview: reqwest::Client, + pub well_known: reqwest::Client, + pub federation: reqwest::Client, + pub sender: reqwest::Client, + pub appservice: reqwest::Client, + pub pusher: reqwest::Client, +} + +impl Client { + pub fn new(config: &Config, resolver: &Arc) -> Client { + Client { + default: Self::base(config).unwrap().build().unwrap(), + + url_preview: Self::base(config).unwrap().redirect(redirect::Policy::limited(3)).build().unwrap(), + + well_known: Self::base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .connect_timeout(Duration::from_secs(config.well_known_conn_timeout)) + .timeout(Duration::from_secs(config.well_known_timeout)) + .pool_max_idle_per_host(0) + .redirect(redirect::Policy::limited(4)) + .build() + .unwrap(), + + federation: Self::base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .timeout(Duration::from_secs(config.federation_timeout)) + .pool_max_idle_per_host(config.federation_idle_per_host.into()) + .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) + .redirect(redirect::Policy::limited(3)) + .build() + .unwrap(), + + sender: Self::base(config) + .unwrap() + .dns_resolver(resolver.clone()) + .timeout(Duration::from_secs(config.sender_timeout)) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.sender_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + appservice: Self::base(config) + .unwrap() + .connect_timeout(Duration::from_secs(5)) + .timeout(Duration::from_secs(config.appservice_timeout)) + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.appservice_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + + pusher: Self::base(config) + .unwrap() + .pool_max_idle_per_host(1) + .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) + .redirect(redirect::Policy::limited(2)) + .build() + .unwrap(), + } + } + + fn base(config: &Config) -> Result { + let builder = reqwest::Client::builder() + .hickory_dns(true) + .timeout(Duration::from_secs(config.request_timeout)) + .connect_timeout(Duration::from_secs(config.request_conn_timeout)) + .pool_max_idle_per_host(config.request_idle_per_host.into()) + .pool_idle_timeout(Duration::from_secs(config.request_idle_timeout)) + .user_agent("Conduwuit".to_owned() + "/" + env!("CARGO_PKG_VERSION")) + .redirect(redirect::Policy::limited(6)); + + if let Some(proxy) = config.proxy.to_proxy()? { + Ok(builder.proxy(proxy)) + } else { + Ok(builder) + } + } +} diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 348a2b11..61a00e15 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,32 +1,20 @@ use std::{ collections::{BTreeMap, HashMap}, - error::Error as StdError, fs, - future::{self, Future}, - iter, - net::{IpAddr, SocketAddr}, + future::Future, path::PathBuf, sync::{ atomic::{self, AtomicBool}, - Arc, RwLock as StdRwLock, + Arc, }, - time::{Duration, Instant}, + time::Instant, }; use argon2::Argon2; use base64::{engine::general_purpose, Engine as _}; pub use data::Data; -use futures_util::FutureExt; use hickory_resolver::TokioAsyncResolver; -use hyper::{ - client::connect::dns::{GaiResolver, Name}, - service::Service as HyperService, -}; use regex::RegexSet; -use reqwest::{ - dns::{Addrs, Resolve, Resolving}, - redirect, -}; use ruma::{ api::{ client::sync::sync_events, @@ -39,12 +27,12 @@ use ruma::{ use tokio::sync::{broadcast, watch::Receiver, Mutex, RwLock, Semaphore}; use tracing::{error, info}; -use crate::{api::server_server::FedDest, services, Config, Error, Result}; +use crate::{services, Config, Result}; +pub mod client; mod data; +pub mod resolver; -type WellKnownMap = HashMap; -type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type SyncHandle = ( Option, // since @@ -54,13 +42,11 @@ type SyncHandle = ( pub struct Service<'a> { pub db: &'static dyn Data, - pub actual_destination_cache: Arc>, // actual_destination, host - pub tls_name_override: Arc>, pub config: Config, keypair: Arc, - dns_resolver: TokioAsyncResolver, jwt_decoding_key: Option, - pub client: Client, + pub resolver: Arc, + pub client: client::Client, pub stable_room_versions: Vec, pub unstable_room_versions: Vec, pub bad_event_ratelimiter: Arc>>, @@ -79,16 +65,6 @@ pub struct Service<'a> { pub argon: Argon2<'a>, } -pub struct Client { - pub default: reqwest::Client, - pub url_preview: reqwest::Client, - pub well_known: reqwest::Client, - pub federation: reqwest::Client, - pub sender: reqwest::Client, - pub appservice: reqwest::Client, - pub pusher: reqwest::Client, -} - /// Handles "rotation" of long-polling requests. "Rotation" in this context is /// similar to "rotation" of log files and the like. /// @@ -117,121 +93,6 @@ impl Default for RotationHandler { fn default() -> Self { Self::new() } } -struct Resolver { - inner: GaiResolver, - overrides: Arc>, -} - -impl Resolver { - fn new(overrides: Arc>) -> Self { - Resolver { - inner: GaiResolver::new(), - overrides, - } - } -} - -impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { - self.overrides - .read() - .unwrap() - .get(name.as_str()) - .and_then(|(override_name, port)| { - override_name.first().map(|first_name| { - let x: Box + Send> = - Box::new(iter::once(SocketAddr::new(*first_name, *port))); - let x: Resolving = Box::pin(future::ready(Ok(x))); - x - }) - }) - .unwrap_or_else(|| { - let this = &mut self.inner.clone(); - Box::pin(HyperService::::call(this, name).map(|result| { - result - .map(|addrs| -> Addrs { Box::new(addrs) }) - .map_err(|err| -> Box { Box::new(err) }) - })) - }) - } -} - -impl Client { - pub fn new(config: &Config, tls_name_override: &Arc>) -> Client { - let resolver = Arc::new(Resolver::new(tls_name_override.clone())); - Client { - default: Self::base(config).unwrap().build().unwrap(), - - url_preview: Self::base(config).unwrap().redirect(redirect::Policy::limited(3)).build().unwrap(), - - well_known: Self::base(config) - .unwrap() - .dns_resolver(resolver.clone()) - .connect_timeout(Duration::from_secs(config.well_known_conn_timeout)) - .timeout(Duration::from_secs(config.well_known_timeout)) - .pool_max_idle_per_host(0) - .redirect(redirect::Policy::limited(4)) - .build() - .unwrap(), - - federation: Self::base(config) - .unwrap() - .dns_resolver(resolver.clone()) - .timeout(Duration::from_secs(config.federation_timeout)) - .pool_max_idle_per_host(config.federation_idle_per_host.into()) - .pool_idle_timeout(Duration::from_secs(config.federation_idle_timeout)) - .redirect(redirect::Policy::limited(3)) - .build() - .unwrap(), - - sender: Self::base(config) - .unwrap() - .dns_resolver(resolver) - .timeout(Duration::from_secs(config.sender_timeout)) - .pool_max_idle_per_host(1) - .pool_idle_timeout(Duration::from_secs(config.sender_idle_timeout)) - .redirect(redirect::Policy::limited(2)) - .build() - .unwrap(), - - appservice: Self::base(config) - .unwrap() - .connect_timeout(Duration::from_secs(5)) - .timeout(Duration::from_secs(config.appservice_timeout)) - .pool_max_idle_per_host(1) - .pool_idle_timeout(Duration::from_secs(config.appservice_idle_timeout)) - .redirect(redirect::Policy::limited(2)) - .build() - .unwrap(), - - pusher: Self::base(config) - .unwrap() - .pool_max_idle_per_host(1) - .pool_idle_timeout(Duration::from_secs(config.pusher_idle_timeout)) - .redirect(redirect::Policy::limited(2)) - .build() - .unwrap(), - } - } - - fn base(config: &Config) -> Result { - let builder = reqwest::Client::builder() - .hickory_dns(true) - .timeout(Duration::from_secs(config.request_timeout)) - .connect_timeout(Duration::from_secs(config.request_conn_timeout)) - .pool_max_idle_per_host(config.request_idle_per_host.into()) - .pool_idle_timeout(Duration::from_secs(config.request_idle_timeout)) - .user_agent("Conduwuit".to_owned() + "/" + env!("CARGO_PKG_VERSION")) - .redirect(redirect::Policy::limited(6)); - - if let Some(proxy) = config.proxy.to_proxy()? { - Ok(builder.proxy(proxy)) - } else { - Ok(builder) - } - } -} - impl Service<'_> { pub fn load(db: &'static dyn Data, config: &Config) -> Result { let keypair = db.load_keypair(); @@ -245,11 +106,11 @@ impl Service<'_> { }, }; - let tls_name_override = Arc::new(StdRwLock::new(TlsNameMap::new())); - let jwt_decoding_key = config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); + let resolver = Arc::new(resolver::Resolver::new(config)); + // Supported and stable room versions let stable_room_versions = vec![ RoomVersionId::V6, @@ -276,13 +137,8 @@ impl Service<'_> { db, config: config.clone(), keypair: Arc::new(keypair), - dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { - error!("Failed to set up trust dns resolver with system config: {}", e); - Error::bad_config("Failed to set up trust dns resolver with system config.") - })?, - actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), - tls_name_override: tls_name_override.clone(), - client: Client::new(config, &tls_name_override), + resolver: resolver.clone(), + client: client::Client::new(config, &resolver), jwt_decoding_key, stable_room_versions, unstable_room_versions, @@ -372,7 +228,9 @@ impl Service<'_> { pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } - pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.dns_resolver } + pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.resolver.resolver } + + pub fn actual_destinations(&self) -> &Arc> { &self.resolver.destinations } pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } diff --git a/src/service/globals/resolver.rs b/src/service/globals/resolver.rs new file mode 100644 index 00000000..a8f90894 --- /dev/null +++ b/src/service/globals/resolver.rs @@ -0,0 +1,72 @@ +use std::{ + collections::HashMap, + error::Error as StdError, + future::{self}, + iter, + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock as StdRwLock}, +}; + +use futures_util::FutureExt; +use hickory_resolver::TokioAsyncResolver; +use hyper::{ + client::connect::dns::{GaiResolver, Name}, + service::Service as HyperService, +}; +use reqwest::dns::{Addrs, Resolve, Resolving}; +use ruma::OwnedServerName; +use tokio::sync::RwLock; +use tracing::error; + +use crate::{api::server_server::FedDest, Config, Error}; + +pub type WellKnownMap = HashMap; +pub type TlsNameMap = HashMap, u16)>; + +pub struct Resolver { + inner: GaiResolver, + pub overrides: Arc>, + pub destinations: Arc>, // actual_destination, host + pub resolver: TokioAsyncResolver, +} + +impl Resolver { + pub(crate) fn new(_config: &Config) -> Self { + Resolver { + inner: GaiResolver::new(), + overrides: Arc::new(StdRwLock::new(TlsNameMap::new())), + destinations: Arc::new(RwLock::new(WellKnownMap::new())), + resolver: TokioAsyncResolver::tokio_from_system_conf() + .map_err(|e| { + error!("Failed to set up trust dns resolver with system config: {}", e); + Error::bad_config("Failed to set up trust dns resolver with system config.") + }) + .unwrap(), + } + } +} + +impl Resolve for Resolver { + fn resolve(&self, name: Name) -> Resolving { + self.overrides + .read() + .unwrap() + .get(name.as_str()) + .and_then(|(override_name, port)| { + override_name.first().map(|first_name| { + let x: Box + Send> = + Box::new(iter::once(SocketAddr::new(*first_name, *port))); + let x: Resolving = Box::pin(future::ready(Ok(x))); + x + }) + }) + .unwrap_or_else(|| { + let this = &mut self.inner.clone(); + Box::pin(HyperService::::call(this, name).map(|result| { + result + .map(|addrs| -> Addrs { Box::new(addrs) }) + .map_err(|err| -> Box { Box::new(err) }) + })) + }) + } +} diff --git a/src/service/mod.rs b/src/service/mod.rs index 3d5e9ff4..8f19e029 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -205,10 +205,10 @@ roomid_spacehierarchy_cache: {roomid_spacehierarchy_cache}" self.rooms.spaces.roomid_spacehierarchy_cache.lock().await.clear(); } if amount > 6 { - self.globals.tls_name_override.write().unwrap().clear(); + self.globals.resolver.overrides.write().unwrap().clear(); } if amount > 7 { - self.globals.dns_resolver().clear_cache(); + self.globals.resolver.resolver.clear_cache(); } } }