diff --git a/src/admin/query/resolver.rs b/src/admin/query/resolver.rs index b53661fc..0b6da6fd 100644 --- a/src/admin/query/resolver.rs +++ b/src/admin/query/resolver.rs @@ -1,7 +1,6 @@ -use std::fmt::Write; - use clap::Subcommand; use conduwuit::{utils::time, Result}; +use futures::StreamExt; use ruma::{events::room::message::RoomMessageEventContent, OwnedServerName}; use crate::{admin_command, admin_command_dispatch}; @@ -31,29 +30,19 @@ async fn destinations_cache( writeln!(self, "| Server Name | Destination | Hostname | Expires |").await?; writeln!(self, "| ----------- | ----------- | -------- | ------- |").await?; - let mut out = String::new(); - { - let map = self - .services - .resolver - .cache - .destinations - .read() - .expect("locked"); + let mut destinations = self.services.resolver.cache.destinations().boxed(); - for (name, &CachedDest { ref dest, ref host, expire }) in map.iter() { - if let Some(server_name) = server_name.as_ref() { - if name != server_name { - continue; - } + while let Some((name, CachedDest { dest, host, expire })) = destinations.next().await { + if let Some(server_name) = server_name.as_ref() { + if name != server_name { + continue; } - - let expire = time::format(expire, "%+"); - writeln!(out, "| {name} | {dest} | {host} | {expire} |")?; } - } - self.write_str(out.as_str()).await?; + let expire = time::format(expire, "%+"); + self.write_str(&format!("| {name} | {dest} | {host} | {expire} |\n")) + .await?; + } Ok(RoomMessageEventContent::notice_plain("")) } @@ -65,29 +54,19 @@ async fn overrides_cache(&self, server_name: Option) -> Result Result { - let (result, cached) = if let Some(result) = self.get_cached_destination(server_name) { + let (result, cached) = if let Ok(result) = self.cache.get_destination(server_name).await { (result, true) } else { self.validate_dest(server_name)?; @@ -232,7 +232,7 @@ impl super::Service { #[tracing::instrument(skip_all, name = "well-known")] async fn request_well_known(&self, dest: &str) -> Result> { - if !self.has_cached_override(dest) { + if !self.cache.has_override(dest).await { self.query_and_cache_override(dest, dest, 8448).await?; } @@ -315,7 +315,7 @@ impl super::Service { debug_info!("{overname:?} overriden by {hostname:?}"); } - self.set_cached_override(overname, CachedOverride { + self.cache.set_override(overname, CachedOverride { ips: override_ip.into_iter().take(MAX_IPS).collect(), port, expire: CachedOverride::default_expire(), diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index e309a129..11e6c9bd 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -1,108 +1,103 @@ -use std::{ - collections::HashMap, - net::IpAddr, - sync::{Arc, RwLock}, - time::SystemTime, -}; +use std::{net::IpAddr, sync::Arc, time::SystemTime}; use arrayvec::ArrayVec; use conduwuit::{ - trace, - utils::{math::Expected, rand}, + at, implement, + utils::{math::Expected, rand, stream::TryIgnore}, + Result, }; -use ruma::{OwnedServerName, ServerName}; +use database::{Cbor, Deserialized, Map}; +use futures::{Stream, StreamExt}; +use ruma::ServerName; +use serde::{Deserialize, Serialize}; use super::fed::FedDest; pub struct Cache { - pub destinations: RwLock, // actual_destination, host - pub overrides: RwLock, + destinations: Arc, + overrides: Arc, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct CachedDest { pub dest: FedDest, pub host: String, pub expire: SystemTime, } -#[derive(Clone, Debug)] +#[derive(Clone, Debug, Deserialize, Serialize)] pub struct CachedOverride { pub ips: IpAddrs, pub port: u16, pub expire: SystemTime, } -pub type WellKnownMap = HashMap; -pub type TlsNameMap = HashMap; - pub type IpAddrs = ArrayVec; pub(crate) const MAX_IPS: usize = 3; impl Cache { - pub(super) fn new() -> Arc { + pub(super) fn new(args: &crate::Args<'_>) -> Arc { Arc::new(Self { - destinations: RwLock::new(WellKnownMap::new()), - overrides: RwLock::new(TlsNameMap::new()), + destinations: args.db["servername_destination"].clone(), + overrides: args.db["servername_override"].clone(), }) } } -impl super::Service { - pub fn set_cached_destination( - &self, - name: OwnedServerName, - dest: CachedDest, - ) -> Option { - trace!(?name, ?dest, "set cached destination"); - self.cache - .destinations - .write() - .expect("locked for writing") - .insert(name, dest) - } +#[implement(Cache)] +pub fn set_destination(&self, name: &ServerName, dest: CachedDest) { + self.destinations.raw_put(name, Cbor(dest)); +} - #[must_use] - pub fn get_cached_destination(&self, name: &ServerName) -> Option { - self.cache - .destinations - .read() - .expect("locked for reading") - .get(name) - .cloned() - } +#[implement(Cache)] +pub fn set_override(&self, name: &str, over: CachedOverride) { + self.overrides.raw_put(name, Cbor(over)); +} - pub fn set_cached_override( - &self, - name: &str, - over: CachedOverride, - ) -> Option { - trace!(?name, ?over, "set cached override"); - self.cache - .overrides - .write() - .expect("locked for writing") - .insert(name.into(), over) - } +#[implement(Cache)] +pub async fn get_destination(&self, name: &ServerName) -> Result { + self.destinations + .get(name) + .await + .deserialized::>() + .map(at!(0)) +} - #[must_use] - pub fn get_cached_override(&self, name: &str) -> Option { - self.cache - .overrides - .read() - .expect("locked for reading") - .get(name) - .cloned() - } +#[implement(Cache)] +pub async fn get_override(&self, name: &str) -> Result { + self.overrides + .get(name) + .await + .deserialized::>() + .map(at!(0)) +} - #[must_use] - pub fn has_cached_override(&self, name: &str) -> bool { - self.cache - .overrides - .read() - .expect("locked for reading") - .contains_key(name) - } +#[implement(Cache)] +#[must_use] +pub async fn has_destination(&self, destination: &str) -> bool { + self.destinations.exists(destination).await.is_ok() +} + +#[implement(Cache)] +#[must_use] +pub async fn has_override(&self, destination: &str) -> bool { + self.overrides.exists(destination).await.is_ok() +} + +#[implement(Cache)] +pub fn destinations(&self) -> impl Stream + Send + '_ { + self.destinations + .stream() + .ignore_err() + .map(|item: (&ServerName, Cbor<_>)| (item.0, item.1 .0)) +} + +#[implement(Cache)] +pub fn overrides(&self) -> impl Stream + Send + '_ { + self.overrides + .stream() + .ignore_err() + .map(|item: (&ServerName, Cbor<_>)| (item.0, item.1 .0)) } impl CachedDest { diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index 5c9018ab..ad7768bc 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -88,18 +88,20 @@ impl Resolve for Resolver { impl Resolve for Hooked { fn resolve(&self, name: Name) -> Resolving { - let cached: Option = self - .cache - .overrides - .read() - .expect("locked for reading") - .get(name.as_str()) - .cloned(); + hooked_resolve(self.cache.clone(), self.server.clone(), self.resolver.clone(), name) + .boxed() + } +} - cached.map_or_else( - || resolve_to_reqwest(self.server.clone(), self.resolver.clone(), name).boxed(), - |cached| cached_to_reqwest(cached).boxed(), - ) +async fn hooked_resolve( + cache: Arc, + server: Arc, + resolver: Arc, + name: Name, +) -> Result> { + match cache.get_override(name.as_str()).await { + | Ok(cached) => cached_to_reqwest(cached).await, + | Err(_) => resolve_to_reqwest(server, resolver, name).boxed().await, } } diff --git a/src/service/resolver/fed.rs b/src/service/resolver/fed.rs index 76fc6894..bfe100e7 100644 --- a/src/service/resolver/fed.rs +++ b/src/service/resolver/fed.rs @@ -6,8 +6,9 @@ use std::{ use arrayvec::ArrayString; use conduwuit::utils::math::Expected; +use serde::{Deserialize, Serialize}; -#[derive(Clone, Debug, PartialEq, Eq)] +#[derive(Clone, Debug, Deserialize, PartialEq, Eq, Serialize)] pub enum FedDest { Literal(SocketAddr), Named(String, PortString), diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 6a6289b6..3163b0d0 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -4,9 +4,9 @@ mod dns; pub mod fed; mod tests; -use std::{fmt::Write, sync::Arc}; +use std::sync::Arc; -use conduwuit::{utils, utils::math::Expected, Result, Server}; +use conduwuit::{Result, Server}; use self::{cache::Cache, dns::Resolver}; use crate::{client, Dep}; @@ -25,7 +25,7 @@ struct Services { impl crate::Service for Service { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn build(args: crate::Args<'_>) -> Result> { - let cache = Cache::new(); + let cache = Cache::new(&args); Ok(Arc::new(Self { cache: cache.clone(), resolver: Resolver::build(args.server, cache)?, @@ -36,38 +36,5 @@ impl crate::Service for Service { })) } - fn memory_usage(&self, out: &mut dyn Write) -> Result { - use utils::bytes::pretty; - - let (oc_count, oc_bytes) = self.cache.overrides.read()?.iter().fold( - (0_usize, 0_usize), - |(count, bytes), (key, val)| { - (count.expected_add(1), bytes.expected_add(key.len()).expected_add(val.size())) - }, - ); - - let (dc_count, dc_bytes) = self.cache.destinations.read()?.iter().fold( - (0_usize, 0_usize), - |(count, bytes), (key, val)| { - (count.expected_add(1), bytes.expected_add(key.len()).expected_add(val.size())) - }, - ); - - writeln!(out, "resolver_overrides_cache: {oc_count} ({})", pretty(oc_bytes))?; - writeln!(out, "resolver_destinations_cache: {dc_count} ({})", pretty(dc_bytes))?; - - Ok(()) - } - - fn clear_cache(&self) { - self.cache.overrides.write().expect("write locked").clear(); - self.cache - .destinations - .write() - .expect("write locked") - .clear(); - self.resolver.resolver.clear_cache(); - } - fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } diff --git a/src/service/sending/send.rs b/src/service/sending/send.rs index e2981068..831a1dd8 100644 --- a/src/service/sending/send.rs +++ b/src/service/sending/send.rs @@ -125,7 +125,7 @@ where let result = T::IncomingResponse::try_from_http_response(response); if result.is_ok() && !actual.cached { - resolver.set_cached_destination(dest.to_owned(), CachedDest { + resolver.cache.set_destination(dest, CachedDest { dest: actual.dest.clone(), host: actual.host.clone(), expire: CachedDest::default_expire(),