diff --git a/src/admin/query/appservice.rs b/src/admin/query/appservice.rs index 93c76a7e..0359261a 100644 --- a/src/admin/query/appservice.rs +++ b/src/admin/query/appservice.rs @@ -1,5 +1,6 @@ use clap::Subcommand; use conduwuit::Result; +use futures::TryStreamExt; use crate::Command; @@ -31,7 +32,7 @@ pub(super) async fn process(subcommand: AppserviceCommand, context: &Command<'_> }, | AppserviceCommand::All => { let timer = tokio::time::Instant::now(); - let results = services.appservice.all().await; + let results: Vec<_> = services.appservice.iter_db_ids().try_collect().await?; let query_time = timer.elapsed(); write!(context, "Query completed in {query_time:?}:\n\n```rs\n{results:#?}\n```") diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index 50a60033..7be8a471 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -1,20 +1,20 @@ mod namespace_regex; mod registration_info; -use std::{collections::BTreeMap, sync::Arc}; +use std::{collections::BTreeMap, iter::IntoIterator, sync::Arc}; use async_trait::async_trait; -use conduwuit::{Result, err, utils::stream::TryIgnore}; +use conduwuit::{Result, err, utils::stream::IterStream}; use database::Map; -use futures::{Future, StreamExt, TryStreamExt}; +use futures::{Future, FutureExt, Stream, TryStreamExt}; use ruma::{RoomAliasId, RoomId, UserId, api::appservice::Registration}; -use tokio::sync::RwLock; +use tokio::sync::{RwLock, RwLockReadGuard}; pub use self::{namespace_regex::NamespaceRegex, registration_info::RegistrationInfo}; use crate::{Dep, sending}; pub struct Service { - registration_info: RwLock>, + registration_info: RwLock, services: Services, db: Data, } @@ -27,6 +27,8 @@ struct Data { id_appserviceregistrations: Arc, } +type Registrations = BTreeMap; + #[async_trait] impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -41,19 +43,18 @@ impl crate::Service for Service { })) } - async fn worker(self: Arc) -> Result<()> { + async fn worker(self: Arc) -> Result { // Inserting registrations into cache - for appservice in self.iter_db_ids().await? { - self.registration_info.write().await.insert( - appservice.0, - appservice - .1 - .try_into() - .expect("Should be validated on registration"), - ); - } + self.iter_db_ids() + .try_for_each(async |appservice| { + self.registration_info + .write() + .await + .insert(appservice.0, appservice.1.try_into()?); - Ok(()) + Ok(()) + }) + .await } fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } @@ -84,7 +85,7 @@ impl Service { /// # Arguments /// /// * `service_name` - the registration ID of the appservice - pub async fn unregister_appservice(&self, appservice_id: &str) -> Result<()> { + pub async fn unregister_appservice(&self, appservice_id: &str) -> Result { // removes the appservice registration info self.registration_info .write() @@ -112,15 +113,6 @@ impl Service { .map(|info| info.registration) } - pub async fn iter_ids(&self) -> Vec { - self.registration_info - .read() - .await - .keys() - .cloned() - .collect() - } - pub async fn find_from_token(&self, token: &str) -> Option { self.read() .await @@ -156,15 +148,22 @@ impl Service { .any(|info| info.rooms.is_exclusive_match(room_id.as_str())) } - pub fn read( - &self, - ) -> impl Future>> - { - self.registration_info.read() + pub fn iter_ids(&self) -> impl Stream + Send { + self.read() + .map(|info| info.keys().cloned().collect::>()) + .map(IntoIterator::into_iter) + .map(IterStream::stream) + .flatten_stream() } - #[inline] - pub async fn all(&self) -> Result> { self.iter_db_ids().await } + pub fn iter_db_ids(&self) -> impl Stream> + Send { + self.db + .id_appserviceregistrations + .keys() + .and_then(move |id: &str| async move { + Ok((id.to_owned(), self.get_db_registration(id).await?)) + }) + } pub async fn get_db_registration(&self, id: &str) -> Result { self.db @@ -175,16 +174,7 @@ impl Service { .map_err(|e| err!(Database("Invalid appservice {id:?} registration: {e:?}"))) } - async fn iter_db_ids(&self) -> Result> { - self.db - .id_appserviceregistrations - .keys() - .ignore_err() - .then(|id: String| async move { - let reg = self.get_db_registration(&id).await?; - Ok((id, reg)) - }) - .try_collect() - .await + pub fn read(&self) -> impl Future> + Send { + self.registration_info.read() } }