diff --git a/src/service/resolver/cache.rs b/src/service/resolver/cache.rs index 6b05c00c..cfea7187 100644 --- a/src/service/resolver/cache.rs +++ b/src/service/resolver/cache.rs @@ -7,7 +7,7 @@ use conduwuit::{ utils::{math::Expected, rand, stream::TryIgnore}, }; use database::{Cbor, Deserialized, Map}; -use futures::{Stream, StreamExt}; +use futures::{Stream, StreamExt, future::join}; use ruma::ServerName; use serde::{Deserialize, Serialize}; @@ -45,6 +45,21 @@ impl Cache { } } +#[implement(Cache)] +pub async fn clear(&self) { join(self.clear_destinations(), self.clear_overrides()).await; } + +#[implement(Cache)] +pub async fn clear_destinations(&self) { self.destinations.clear().await; } + +#[implement(Cache)] +pub async fn clear_overrides(&self) { self.overrides.clear().await; } + +#[implement(Cache)] +pub fn del_destination(&self, name: &ServerName) { self.destinations.remove(name); } + +#[implement(Cache)] +pub fn del_override(&self, name: &ServerName) { self.overrides.remove(name); } + #[implement(Cache)] pub fn set_destination(&self, name: &ServerName, dest: &CachedDest) { self.destinations.raw_put(name, Cbor(dest)); diff --git a/src/service/resolver/dns.rs b/src/service/resolver/dns.rs index 98ad7e60..e4245a5b 100644 --- a/src/service/resolver/dns.rs +++ b/src/service/resolver/dns.rs @@ -78,6 +78,10 @@ impl Resolver { server: server.clone(), })) } + + /// Clear the in-memory hickory-dns caches + #[inline] + pub fn clear_cache(&self) { self.resolver.clear_cache(); } } impl Resolve for Resolver { diff --git a/src/service/resolver/mod.rs b/src/service/resolver/mod.rs index 2ec9c0ef..246d6bc1 100644 --- a/src/service/resolver/mod.rs +++ b/src/service/resolver/mod.rs @@ -6,6 +6,7 @@ mod tests; use std::sync::Arc; +use async_trait::async_trait; use conduwuit::{Result, Server, arrayvec::ArrayString, utils::MutexMap}; use self::{cache::Cache, dns::Resolver}; @@ -26,6 +27,7 @@ struct Services { type Resolving = MutexMap; type NameBuf = ArrayString<256>; +#[async_trait] impl crate::Service for Service { #[allow(clippy::as_conversions, clippy::cast_sign_loss, clippy::cast_possible_truncation)] fn build(args: crate::Args<'_>) -> Result> { @@ -41,5 +43,10 @@ impl crate::Service for Service { })) } + async fn clear_cache(&self) { + self.resolver.clear_cache(); + self.cache.clear().await; + } + fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } }