refactor appservice type stuff

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Matthias Ahouansou 2024-03-22 19:21:51 -04:00 committed by June
parent 7c9c5b1d78
commit 60f2471f59
11 changed files with 125 additions and 133 deletions

View file

@ -58,7 +58,6 @@
- Basic validation/checks on user-specified room aliases and custom room ID creations - Basic validation/checks on user-specified room aliases and custom room ID creations
- Warn on unknown config options specified - Warn on unknown config options specified
- Add support for preventing certain room alias names and usernames using regex (via upstream MR) and extended to custom room IDs - Add support for preventing certain room alias names and usernames using regex (via upstream MR) and extended to custom room IDs
- Revamp appservice registration to ruma's `Registration` type which fixes various appservice registration issues, including fixing crashing upon no URL specified (via upstream MR)
- URL preview support (via upstream MR) with various improvements - URL preview support (via upstream MR) with various improvements
- Increased graceful shutdown timeout from a low 60 seconds to 180 seconds to avoid killing connections and let the remaining ones finish processing, and ask systemd for more time to shutdown if needed to prevent systemd's default [`TimeoutStopSec=`](https://www.freedesktop.org/software/systemd/man/latest/systemd.service.html#TimeoutStopSec=) of 90 seconds from killing conduwuit - Increased graceful shutdown timeout from a low 60 seconds to 180 seconds to avoid killing connections and let the remaining ones finish processing, and ask systemd for more time to shutdown if needed to prevent systemd's default [`TimeoutStopSec=`](https://www.freedesktop.org/software/systemd/man/latest/systemd.service.html#TimeoutStopSec=) of 90 seconds from killing conduwuit
- Bumped default max_concurrent_requests to 500 - Bumped default max_concurrent_requests to 500

View file

@ -14,15 +14,12 @@ pub(crate) async fn send_request<T>(registration: Registration, request: T) -> O
where where
T: OutgoingRequest + Debug, T: OutgoingRequest + Debug,
{ {
if let Some(destination) = registration.url { let destination = registration.url?;
let hs_token = registration.hs_token.as_str(); let hs_token = registration.hs_token.as_str();
let mut http_request = request let mut http_request = request
.try_into_http_request::<BytesMut>( .try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(hs_token), &[MatrixVersion::V1_0])
&destination,
SendAccessToken::IfRequired(hs_token),
&[MatrixVersion::V1_0],
)
.map_err(|e| { .map_err(|e| {
warn!("Failed to find destination {}: {}", destination, e); warn!("Failed to find destination {}: {}", destination, e);
Error::BadServerResponse("Invalid destination") Error::BadServerResponse("Invalid destination")
@ -84,11 +81,9 @@ where
let response = T::IncomingResponse::try_from_http_response( let response = T::IncomingResponse::try_from_http_response(
http_response_builder.body(body).expect("reqwest body is valid http body"), http_response_builder.body(body).expect("reqwest body is valid http body"),
); );
Some(response.map_err(|_| { Some(response.map_err(|_| {
warn!("Appservice returned invalid response bytes {}\n{}", destination, url); warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
Error::BadServerResponse("Server returned bad response.") Error::BadServerResponse("Server returned bad response.")
})) }))
} else {
None
}
} }

View file

@ -115,7 +115,7 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get
match services().rooms.alias.resolve_local_alias(&room_alias)? { match services().rooms.alias.resolve_local_alias(&room_alias)? {
Some(r) => room_id = Some(r), Some(r) => room_id = Some(r),
None => { None => {
for appservice in services().appservice.registration_info.read().await.values() { for appservice in services().appservice.read().await.values() {
if appservice.aliases.is_match(room_alias.as_str()) if appservice.aliases.is_match(room_alias.as_str())
&& if let Some(opt_result) = services() && if let Some(opt_result) = services()
.sending .sending

View file

@ -76,11 +76,13 @@ where
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = services().appservice.all().unwrap(); let appservice_registration = if let Some(token) = token {
let appservice_registration = services().appservice.find_from_token(token).await
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token); } else {
None
};
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = let (sender_user, sender_device, sender_servername, from_appservice) = if let Some(info) =
appservice_registration appservice_registration
{ {
match metadata.authentication { match metadata.authentication {
@ -88,7 +90,7 @@ where
let user_id = query_params.user_id.map_or_else( let user_id = query_params.user_id.map_or_else(
|| { || {
UserId::parse_with_server_name( UserId::parse_with_server_name(
registration.sender_localpart.as_str(), info.registration.sender_localpart.as_str(),
services().globals.server_name(), services().globals.server_name(),
) )
.unwrap() .unwrap()
@ -109,7 +111,7 @@ where
let user_id = query_params.user_id.map_or_else( let user_id = query_params.user_id.map_or_else(
|| { || {
UserId::parse_with_server_name( UserId::parse_with_server_name(
registration.sender_localpart.as_str(), info.registration.sender_localpart.as_str(),
services().globals.server_name(), services().globals.server_name(),
) )
.unwrap() .unwrap()

View file

@ -7,7 +7,6 @@ impl service::appservice::Data for KeyValueDatabase {
fn register_appservice(&self, yaml: Registration) -> Result<String> { fn register_appservice(&self, yaml: Registration) -> Result<String> {
let id = yaml.id.as_str(); let id = yaml.id.as_str();
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
Ok(id.to_owned()) Ok(id.to_owned())
} }
@ -19,24 +18,17 @@ impl service::appservice::Data for KeyValueDatabase {
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
fn unregister_appservice(&self, service_name: &str) -> Result<()> { fn unregister_appservice(&self, service_name: &str) -> Result<()> {
self.id_appserviceregistrations.remove(service_name.as_bytes())?; self.id_appserviceregistrations.remove(service_name.as_bytes())?;
self.cached_registrations.write().unwrap().remove(service_name);
Ok(()) Ok(())
} }
fn get_registration(&self, id: &str) -> Result<Option<Registration>> { fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
self.cached_registrations.read().unwrap().get(id).map_or_else(
|| {
self.id_appserviceregistrations self.id_appserviceregistrations
.get(id.as_bytes())? .get(id.as_bytes())?
.map(|bytes| { .map(|bytes| {
serde_yaml::from_slice(&bytes).map_err(|_| { serde_yaml::from_slice(&bytes)
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") .map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
})
}) })
.transpose() .transpose()
},
|r| Ok(Some(r.clone())),
)
} }
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> { fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {

View file

@ -17,7 +17,6 @@ use itertools::Itertools;
use lru_cache::LruCache; use lru_cache::LruCache;
use rand::thread_rng; use rand::thread_rng;
use ruma::{ use ruma::{
api::appservice::Registration,
events::{ events::{
push_rules::{PushRulesEvent, PushRulesEventContent}, push_rules::{PushRulesEvent, PushRulesEventContent},
room::message::RoomMessageEventContent, room::message::RoomMessageEventContent,
@ -179,7 +178,6 @@ pub struct KeyValueDatabase {
//pub pusher: pusher::PushData, //pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>, pub(super) senderkey_pusher: Arc<dyn KvTree>,
pub(super) cached_registrations: Arc<RwLock<HashMap<String, Registration>>>,
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>, pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>, pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>, pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
@ -379,7 +377,6 @@ impl KeyValueDatabase {
global: builder.open_tree("global")?, global: builder.open_tree("global")?,
server_signingkeys: builder.open_tree("server_signingkeys")?, server_signingkeys: builder.open_tree("server_signingkeys")?,
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
pdu_cache: Mutex::new(LruCache::new( pdu_cache: Mutex::new(LruCache::new(
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"), config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
)), )),
@ -992,14 +989,6 @@ impl KeyValueDatabase {
); );
} }
// Inserting registrations into cache
for appservice in services().appservice.all()? {
services().appservice.registration_info.write().await.insert(
appservice.0,
appservice.1.try_into().expect("Should be validated on registration"),
);
}
services().admin.start_handler(); services().admin.start_handler();
// Set emergency access for the conduit user // Set emergency access for the conduit user

View file

@ -623,8 +623,8 @@ impl Service {
}, },
AppserviceCommand::Show { AppserviceCommand::Show {
appservice_identifier, appservice_identifier,
} => match services().appservice.get_registration(&appservice_identifier) { } => match services().appservice.get_registration(&appservice_identifier).await {
Ok(Some(config)) => { Some(config) => {
let config_str = let config_str =
serde_yaml::to_string(&config).expect("config should've been validated on register"); serde_yaml::to_string(&config).expect("config should've been validated on register");
let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,); let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,);
@ -635,21 +635,12 @@ impl Service {
); );
RoomMessageEventContent::text_html(output, output_html) RoomMessageEventContent::text_html(output, output_html)
}, },
Ok(None) => RoomMessageEventContent::text_plain("Appservice does not exist."), None => RoomMessageEventContent::text_plain("Appservice does not exist."),
Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."),
}, },
AppserviceCommand::List => { AppserviceCommand::List => {
if let Ok(appservices) = services().appservice.iter_ids().map(Iterator::collect::<Vec<_>>) { let appservices = services().appservice.iter_ids().await;
let count = appservices.len(); let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", "));
let output = format!(
"Appservices ({}): {}",
count,
appservices.into_iter().filter_map(std::result::Result::ok).collect::<Vec<_>>().join(", ")
);
RoomMessageEventContent::text_plain(output) RoomMessageEventContent::text_plain(output)
} else {
RoomMessageEventContent::text_plain("Failed to get appservices.")
}
}, },
}, },
AdminCommand::Media(command) => { AdminCommand::Media(command) => {

View file

@ -1,8 +1,9 @@
mod data; mod data;
use std::collections::HashMap; use std::collections::BTreeMap;
pub(crate) use data::Data; pub(crate) use data::Data;
use futures_util::Future;
use regex::RegexSet; use regex::RegexSet;
use ruma::api::appservice::{Namespace, Registration}; use ruma::api::appservice::{Namespace, Registration};
use tokio::sync::RwLock; use tokio::sync::RwLock;
@ -10,6 +11,7 @@ use tokio::sync::RwLock;
use crate::{services, Result}; use crate::{services, Result};
/// Compiled regular expressions for a namespace /// Compiled regular expressions for a namespace
#[derive(Clone, Debug)]
pub struct NamespaceRegex { pub struct NamespaceRegex {
pub exclusive: Option<RegexSet>, pub exclusive: Option<RegexSet>,
pub non_exclusive: Option<RegexSet>, pub non_exclusive: Option<RegexSet>,
@ -71,7 +73,8 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
} }
} }
/// Compiled regular expressions for an appservice /// Appservice registration combined with its compiled regular expressions.
#[derive(Clone, Debug)]
pub struct RegistrationInfo { pub struct RegistrationInfo {
pub registration: Registration, pub registration: Registration,
pub users: NamespaceRegex, pub users: NamespaceRegex,
@ -94,10 +97,26 @@ impl TryFrom<Registration> for RegistrationInfo {
pub struct Service { pub struct Service {
pub db: &'static dyn Data, pub db: &'static dyn Data,
pub registration_info: RwLock<HashMap<String, RegistrationInfo>>, registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
} }
impl Service { impl Service {
pub fn build(db: &'static dyn Data) -> Result<Self> {
let mut registration_info = BTreeMap::new();
// Inserting registrations into cache
for appservice in db.all()? {
registration_info.insert(
appservice.0,
appservice.1.try_into().expect("Should be validated on registration"),
);
}
Ok(Self {
db,
registration_info: RwLock::new(registration_info),
})
}
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> { pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?); services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?);
@ -116,9 +135,17 @@ impl Service {
self.db.unregister_appservice(service_name) self.db.unregister_appservice(service_name)
} }
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) } pub async fn get_registration(&self, id: &str) -> Option<Registration> {
self.registration_info.read().await.get(id).cloned().map(|info| info.registration)
}
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() } pub async fn iter_ids(&self) -> Vec<String> { self.registration_info.read().await.keys().cloned().collect() }
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() } pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
self.read().await.values().find(|info| info.registration.as_token == token).cloned()
}
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
self.registration_info.read()
}
} }

View file

@ -55,10 +55,7 @@ impl Services<'_> {
db: &'static D, config: Config, db: &'static D, config: Config,
) -> Result<Self> { ) -> Result<Self> {
Ok(Self { Ok(Self {
appservice: appservice::Service { appservice: appservice::Service::build(db)?,
db,
registration_info: RwLock::new(HashMap::new()),
},
pusher: pusher::Service { pusher: pusher::Service {
db, db,
}, },

View file

@ -510,7 +510,7 @@ impl Service {
} }
} }
for appservice in services().appservice.registration_info.read().await.values() { for appservice in services().appservice.read().await.values() {
if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? { if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? {
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
continue; continue;

View file

@ -502,7 +502,7 @@ impl Service {
let permit = services().sending.maximum_requests.acquire().await; let permit = services().sending.maximum_requests.acquire().await;
let response = match appservice_server::send_request( let response = match appservice_server::send_request(
services().appservice.get_registration(id).map_err(|e| (kind.clone(), e))?.ok_or_else(|| { services().appservice.get_registration(id).await.ok_or_else(|| {
( (
kind.clone(), kind.clone(),
Error::bad_database("[Appservice] Could not load registration from db."), Error::bad_database("[Appservice] Could not load registration from db."),