refactor appservice type stuff
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
7c9c5b1d78
commit
60f2471f59
11 changed files with 125 additions and 133 deletions
|
@ -14,81 +14,76 @@ pub(crate) async fn send_request<T>(registration: Registration, request: T) -> O
|
|||
where
|
||||
T: OutgoingRequest + Debug,
|
||||
{
|
||||
if let Some(destination) = registration.url {
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
let destination = registration.url?;
|
||||
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(
|
||||
&destination,
|
||||
SendAccessToken::IfRequired(hs_token),
|
||||
&[MatrixVersion::V1_0],
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(BytesMut::freeze);
|
||||
let hs_token = registration.hs_token.as_str();
|
||||
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
"?"
|
||||
};
|
||||
let mut http_request = request
|
||||
.try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(hs_token), &[MatrixVersion::V1_0])
|
||||
.map_err(|e| {
|
||||
warn!("Failed to find destination {}: {}", destination, e);
|
||||
Error::BadServerResponse("Invalid destination")
|
||||
})
|
||||
.unwrap()
|
||||
.map(BytesMut::freeze);
|
||||
|
||||
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
|
||||
let mut reqwest_request =
|
||||
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services().globals.client.appservice.execute(reqwest_request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
},
|
||||
};
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
Some(response.map_err(|_| {
|
||||
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
let mut parts = http_request.uri().clone().into_parts();
|
||||
let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned();
|
||||
let symbol = if old_path_and_query.contains('?') {
|
||||
"&"
|
||||
} else {
|
||||
None
|
||||
"?"
|
||||
};
|
||||
|
||||
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||
|
||||
let mut reqwest_request =
|
||||
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||
|
||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||
|
||||
let url = reqwest_request.url().clone();
|
||||
let mut response = match services().globals.client.appservice.execute(reqwest_request).await {
|
||||
Ok(r) => r,
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Could not send request to appservice {} at {}: {}",
|
||||
registration.id, destination, e
|
||||
);
|
||||
return Some(Err(e.into()));
|
||||
},
|
||||
};
|
||||
|
||||
// reqwest::Response -> http::Response conversion
|
||||
let status = response.status();
|
||||
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||
mem::swap(
|
||||
response.headers_mut(),
|
||||
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||
);
|
||||
|
||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||
warn!("server error: {}", e);
|
||||
Vec::new().into()
|
||||
}); // TODO: handle timeout
|
||||
|
||||
if !status.is_success() {
|
||||
warn!(
|
||||
"Appservice returned bad response {} {}\n{}\n{:?}",
|
||||
destination,
|
||||
status,
|
||||
url,
|
||||
utils::string_from_bytes(&body)
|
||||
);
|
||||
}
|
||||
|
||||
let response = T::IncomingResponse::try_from_http_response(
|
||||
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||
);
|
||||
|
||||
Some(response.map_err(|_| {
|
||||
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||
Error::BadServerResponse("Server returned bad response.")
|
||||
}))
|
||||
}
|
||||
|
|
|
@ -115,7 +115,7 @@ pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get
|
|||
match services().rooms.alias.resolve_local_alias(&room_alias)? {
|
||||
Some(r) => room_id = Some(r),
|
||||
None => {
|
||||
for appservice in services().appservice.registration_info.read().await.values() {
|
||||
for appservice in services().appservice.read().await.values() {
|
||||
if appservice.aliases.is_match(room_alias.as_str())
|
||||
&& if let Some(opt_result) = services()
|
||||
.sending
|
||||
|
|
|
@ -76,11 +76,13 @@ where
|
|||
|
||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration =
|
||||
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
let appservice_registration = if let Some(token) = token {
|
||||
services().appservice.find_from_token(token).await
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some(info) =
|
||||
appservice_registration
|
||||
{
|
||||
match metadata.authentication {
|
||||
|
@ -88,7 +90,7 @@ where
|
|||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
info.registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
|
@ -109,7 +111,7 @@ where
|
|||
let user_id = query_params.user_id.map_or_else(
|
||||
|| {
|
||||
UserId::parse_with_server_name(
|
||||
registration.sender_localpart.as_str(),
|
||||
info.registration.sender_localpart.as_str(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
.unwrap()
|
||||
|
|
|
@ -7,7 +7,6 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
@ -19,24 +18,17 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
},
|
||||
|r| Ok(Some(r.clone())),
|
||||
)
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid registration bytes in id_appserviceregistrations."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
|
|
|
@ -17,7 +17,6 @@ use itertools::Itertools;
|
|||
use lru_cache::LruCache;
|
||||
use rand::thread_rng;
|
||||
use ruma::{
|
||||
api::appservice::Registration,
|
||||
events::{
|
||||
push_rules::{PushRulesEvent, PushRulesEventContent},
|
||||
room::message::RoomMessageEventContent,
|
||||
|
@ -179,7 +178,6 @@ pub struct KeyValueDatabase {
|
|||
//pub pusher: pusher::PushData,
|
||||
pub(super) senderkey_pusher: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) cached_registrations: Arc<RwLock<HashMap<String, Registration>>>,
|
||||
pub(super) pdu_cache: Mutex<LruCache<OwnedEventId, Arc<PduEvent>>>,
|
||||
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
|
||||
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
|
||||
|
@ -379,7 +377,6 @@ impl KeyValueDatabase {
|
|||
global: builder.open_tree("global")?,
|
||||
server_signingkeys: builder.open_tree("server_signingkeys")?,
|
||||
|
||||
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
||||
pdu_cache: Mutex::new(LruCache::new(
|
||||
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
|
||||
)),
|
||||
|
@ -992,14 +989,6 @@ impl KeyValueDatabase {
|
|||
);
|
||||
}
|
||||
|
||||
// Inserting registrations into cache
|
||||
for appservice in services().appservice.all()? {
|
||||
services().appservice.registration_info.write().await.insert(
|
||||
appservice.0,
|
||||
appservice.1.try_into().expect("Should be validated on registration"),
|
||||
);
|
||||
}
|
||||
|
||||
services().admin.start_handler();
|
||||
|
||||
// Set emergency access for the conduit user
|
||||
|
|
|
@ -623,8 +623,8 @@ impl Service {
|
|||
},
|
||||
AppserviceCommand::Show {
|
||||
appservice_identifier,
|
||||
} => match services().appservice.get_registration(&appservice_identifier) {
|
||||
Ok(Some(config)) => {
|
||||
} => match services().appservice.get_registration(&appservice_identifier).await {
|
||||
Some(config) => {
|
||||
let config_str =
|
||||
serde_yaml::to_string(&config).expect("config should've been validated on register");
|
||||
let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,);
|
||||
|
@ -635,21 +635,12 @@ impl Service {
|
|||
);
|
||||
RoomMessageEventContent::text_html(output, output_html)
|
||||
},
|
||||
Ok(None) => RoomMessageEventContent::text_plain("Appservice does not exist."),
|
||||
Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."),
|
||||
None => RoomMessageEventContent::text_plain("Appservice does not exist."),
|
||||
},
|
||||
AppserviceCommand::List => {
|
||||
if let Ok(appservices) = services().appservice.iter_ids().map(Iterator::collect::<Vec<_>>) {
|
||||
let count = appservices.len();
|
||||
let output = format!(
|
||||
"Appservices ({}): {}",
|
||||
count,
|
||||
appservices.into_iter().filter_map(std::result::Result::ok).collect::<Vec<_>>().join(", ")
|
||||
);
|
||||
RoomMessageEventContent::text_plain(output)
|
||||
} else {
|
||||
RoomMessageEventContent::text_plain("Failed to get appservices.")
|
||||
}
|
||||
let appservices = services().appservice.iter_ids().await;
|
||||
let output = format!("Appservices ({}): {}", appservices.len(), appservices.join(", "));
|
||||
RoomMessageEventContent::text_plain(output)
|
||||
},
|
||||
},
|
||||
AdminCommand::Media(command) => {
|
||||
|
|
|
@ -1,8 +1,9 @@
|
|||
mod data;
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use futures_util::Future;
|
||||
use regex::RegexSet;
|
||||
use ruma::api::appservice::{Namespace, Registration};
|
||||
use tokio::sync::RwLock;
|
||||
|
@ -10,6 +11,7 @@ use tokio::sync::RwLock;
|
|||
use crate::{services, Result};
|
||||
|
||||
/// Compiled regular expressions for a namespace
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct NamespaceRegex {
|
||||
pub exclusive: Option<RegexSet>,
|
||||
pub non_exclusive: Option<RegexSet>,
|
||||
|
@ -71,7 +73,8 @@ impl TryFrom<Vec<Namespace>> for NamespaceRegex {
|
|||
}
|
||||
}
|
||||
|
||||
/// Compiled regular expressions for an appservice
|
||||
/// Appservice registration combined with its compiled regular expressions.
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct RegistrationInfo {
|
||||
pub registration: Registration,
|
||||
pub users: NamespaceRegex,
|
||||
|
@ -94,10 +97,26 @@ impl TryFrom<Registration> for RegistrationInfo {
|
|||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
pub registration_info: RwLock<HashMap<String, RegistrationInfo>>,
|
||||
registration_info: RwLock<BTreeMap<String, RegistrationInfo>>,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn build(db: &'static dyn Data) -> Result<Self> {
|
||||
let mut registration_info = BTreeMap::new();
|
||||
// Inserting registrations into cache
|
||||
for appservice in db.all()? {
|
||||
registration_info.insert(
|
||||
appservice.0,
|
||||
appservice.1.try_into().expect("Should be validated on registration"),
|
||||
);
|
||||
}
|
||||
|
||||
Ok(Self {
|
||||
db,
|
||||
registration_info: RwLock::new(registration_info),
|
||||
})
|
||||
}
|
||||
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub async fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
services().appservice.registration_info.write().await.insert(yaml.id.clone(), yaml.clone().try_into()?);
|
||||
|
@ -116,9 +135,17 @@ impl Service {
|
|||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
|
||||
pub async fn get_registration(&self, id: &str) -> Option<Registration> {
|
||||
self.registration_info.read().await.get(id).cloned().map(|info| info.registration)
|
||||
}
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
|
||||
pub async fn iter_ids(&self) -> Vec<String> { self.registration_info.read().await.keys().cloned().collect() }
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
|
||||
pub async fn find_from_token(&self, token: &str) -> Option<RegistrationInfo> {
|
||||
self.read().await.values().find(|info| info.registration.as_token == token).cloned()
|
||||
}
|
||||
|
||||
pub fn read(&self) -> impl Future<Output = tokio::sync::RwLockReadGuard<'_, BTreeMap<String, RegistrationInfo>>> {
|
||||
self.registration_info.read()
|
||||
}
|
||||
}
|
||||
|
|
|
@ -55,10 +55,7 @@ impl Services<'_> {
|
|||
db: &'static D, config: Config,
|
||||
) -> Result<Self> {
|
||||
Ok(Self {
|
||||
appservice: appservice::Service {
|
||||
db,
|
||||
registration_info: RwLock::new(HashMap::new()),
|
||||
},
|
||||
appservice: appservice::Service::build(db)?,
|
||||
pusher: pusher::Service {
|
||||
db,
|
||||
},
|
||||
|
|
|
@ -510,7 +510,7 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
for appservice in services().appservice.registration_info.read().await.values() {
|
||||
for appservice in services().appservice.read().await.values() {
|
||||
if services().rooms.state_cache.appservice_in_room(&pdu.room_id, appservice)? {
|
||||
services().sending.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
|
||||
continue;
|
||||
|
|
|
@ -502,7 +502,7 @@ impl Service {
|
|||
let permit = services().sending.maximum_requests.acquire().await;
|
||||
|
||||
let response = match appservice_server::send_request(
|
||||
services().appservice.get_registration(id).map_err(|e| (kind.clone(), e))?.ok_or_else(|| {
|
||||
services().appservice.get_registration(id).await.ok_or_else(|| {
|
||||
(
|
||||
kind.clone(),
|
||||
Error::bad_database("[Appservice] Could not load registration from db."),
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue