de-global services() from api

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-07-16 08:05:25 +00:00
parent 463f1a1287
commit 8b6018d77d
61 changed files with 1485 additions and 1320 deletions

View file

@ -530,7 +530,7 @@ pub(super) async fn force_set_room_state_from_server(
for result in remote_state_response for result in remote_state_response
.pdus .pdus
.iter() .iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map))
{ {
let Ok((event_id, value)) = result.await else { let Ok((event_id, value)) = result.await else {
continue; continue;
@ -558,7 +558,7 @@ pub(super) async fn force_set_room_state_from_server(
for result in remote_state_response for result in remote_state_response
.auth_chain .auth_chain
.iter() .iter()
.map(|pdu| validate_and_add_event_id(pdu, &room_version, &pub_key_map)) .map(|pdu| validate_and_add_event_id(services(), pdu, &room_version, &pub_key_map))
{ {
let Ok((event_id, value)) = result.await else { let Ok((event_id, value)) = result.await else {
continue; continue;

View file

@ -128,7 +128,7 @@ async fn ban_room(
&local_user, &room_id &local_user, &room_id
); );
if let Err(e) = leave_room(&local_user, &room_id, None).await { if let Err(e) = leave_room(services(), &local_user, &room_id, None).await {
warn!(%e, "Failed to leave room"); warn!(%e, "Failed to leave room");
} }
} }
@ -151,7 +151,7 @@ async fn ban_room(
}) })
}) { }) {
debug!("Attempting leave for user {} in room {}", &local_user, &room_id); debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(&local_user, &room_id, None).await { if let Err(e) = leave_room(services(), &local_user, &room_id, None).await {
error!( error!(
"Error attempting to make local user {} leave room {} during room banning: {}", "Error attempting to make local user {} leave room {} during room banning: {}",
&local_user, &room_id, e &local_user, &room_id, e
@ -334,7 +334,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo
"Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)",
&local_user, room_id &local_user, room_id
); );
if let Err(e) = leave_room(&local_user, &room_id, None).await { if let Err(e) = leave_room(services(), &local_user, &room_id, None).await {
warn!(%e, "Failed to leave room"); warn!(%e, "Failed to leave room");
} }
} }
@ -358,7 +358,7 @@ async fn ban_list_of_rooms(body: Vec<&str>, force: bool, disable_federation: boo
}) })
}) { }) {
debug!("Attempting leave for user {} in room {}", &local_user, &room_id); debug!("Attempting leave for user {} in room {}", &local_user, &room_id);
if let Err(e) = leave_room(&local_user, &room_id, None).await { if let Err(e) = leave_room(services(), &local_user, &room_id, None).await {
error!( error!(
"Error attempting to make local user {} leave room {} during bulk room banning: {}", "Error attempting to make local user {} leave room {} during bulk room banning: {}",
&local_user, &room_id, e &local_user, &room_id, e

View file

@ -101,6 +101,7 @@ pub(super) async fn create(
if let Some(room_id_server_name) = room.server_name() { if let Some(room_id_server_name) = room.server_name() {
match join_room_by_id_helper( match join_room_by_id_helper(
services(),
&user_id, &user_id,
room, room,
Some("Automatically joining this room upon registration".to_owned()), Some("Automatically joining this room upon registration".to_owned()),
@ -158,9 +159,9 @@ pub(super) async fn deactivate(
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
update_displayname(user_id.clone(), None, all_joined_rooms.clone()).await?; update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?;
update_avatar_url(user_id.clone(), None, None, all_joined_rooms).await?; update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?;
leave_all_rooms(&user_id).await; leave_all_rooms(services(), &user_id).await;
} }
Ok(RoomMessageEventContent::text_plain(format!( Ok(RoomMessageEventContent::text_plain(format!(
@ -262,9 +263,9 @@ pub(super) async fn deactivate_all(
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
update_displayname(user_id.clone(), None, all_joined_rooms.clone()).await?; update_displayname(services(), user_id.clone(), None, all_joined_rooms.clone()).await?;
update_avatar_url(user_id.clone(), None, None, all_joined_rooms).await?; update_avatar_url(services(), user_id.clone(), None, None, all_joined_rooms).await?;
leave_all_rooms(&user_id).await; leave_all_rooms(services(), &user_id).await;
} }
}, },
Err(e) => { Err(e) => {
@ -347,7 +348,7 @@ pub(super) async fn force_join_room(
let room_id = services().rooms.alias.resolve(&room_id).await?; let room_id = services().rooms.alias.resolve(&room_id).await?;
assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user"); assert!(service::user_is_local(&user_id), "Parsed user_id must be a local user");
join_room_by_id_helper(&user_id, &room_id, None, &[], None).await?; join_room_by_id_helper(services(), &user_id, &room_id, None, &[], None).await?;
Ok(RoomMessageEventContent::notice_markdown(format!( Ok(RoomMessageEventContent::notice_markdown(format!(
"{user_id} has been joined to {room_id}.", "{user_id} has been joined to {room_id}.",

View file

@ -1,5 +1,6 @@
use std::fmt::Write; use std::fmt::Write;
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::debug_info; use conduit::debug_info;
use register::RegistrationKind; use register::RegistrationKind;
@ -22,7 +23,6 @@ use tracing::{error, info, warn};
use super::{join_room_by_id_helper, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use super::{join_room_by_id_helper, DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{ use crate::{
service::user_is_local, service::user_is_local,
services,
utils::{self}, utils::{self},
Error, Result, Ruma, Error, Result, Ruma,
}; };
@ -42,20 +42,21 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
/// invalid when trying to register /// invalid when trying to register
#[tracing::instrument(skip_all, fields(%client), name = "register_available")] #[tracing::instrument(skip_all, fields(%client), name = "register_available")]
pub(crate) async fn get_register_available_route( pub(crate) async fn get_register_available_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_username_availability::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_username_availability::v3::Request>,
) -> Result<get_username_availability::v3::Response> { ) -> Result<get_username_availability::v3::Response> {
// Validate user id // Validate user id
let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services.globals.server_name())
.ok() .ok()
.filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .filter(|user_id| !user_id.is_historical() && user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
// Check if username is creative enough // Check if username is creative enough
if services().users.exists(&user_id)? { if services.users.exists(&user_id)? {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
if services() if services
.globals .globals
.forbidden_usernames() .forbidden_usernames()
.is_match(user_id.localpart()) .is_match(user_id.localpart())
@ -91,9 +92,9 @@ pub(crate) async fn get_register_available_route(
#[allow(clippy::doc_markdown)] #[allow(clippy::doc_markdown)]
#[tracing::instrument(skip_all, fields(%client), name = "register")] #[tracing::instrument(skip_all, fields(%client), name = "register")]
pub(crate) async fn register_route( pub(crate) async fn register_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<register::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp, body: Ruma<register::v3::Request>,
) -> Result<register::v3::Response> { ) -> Result<register::v3::Response> {
if !services().globals.allow_registration() && body.appservice_info.is_none() { if !services.globals.allow_registration() && body.appservice_info.is_none() {
info!( info!(
"Registration disabled and request not from known appservice, rejecting registration attempt for username \ "Registration disabled and request not from known appservice, rejecting registration attempt for username \
{:?}", {:?}",
@ -105,8 +106,8 @@ pub(crate) async fn register_route(
let is_guest = body.kind == RegistrationKind::Guest; let is_guest = body.kind == RegistrationKind::Guest;
if is_guest if is_guest
&& (!services().globals.allow_guest_registration() && (!services.globals.allow_guest_registration()
|| (services().globals.allow_registration() && services().globals.config.registration_token.is_some())) || (services.globals.allow_registration() && services.globals.config.registration_token.is_some()))
{ {
info!( info!(
"Guest registration disabled / registration enabled with token configured, rejecting guest registration \ "Guest registration disabled / registration enabled with token configured, rejecting guest registration \
@ -121,7 +122,7 @@ pub(crate) async fn register_route(
// forbid guests from registering if there is not a real admin user yet. give // forbid guests from registering if there is not a real admin user yet. give
// generic user error. // generic user error.
if is_guest && services().users.count()? < 2 { if is_guest && services.users.count()? < 2 {
warn!( warn!(
"Guest account attempted to register before a real admin user has been registered, rejecting \ "Guest account attempted to register before a real admin user has been registered, rejecting \
registration. Guest's initial device name: {:?}", registration. Guest's initial device name: {:?}",
@ -133,16 +134,16 @@ pub(crate) async fn register_route(
let user_id = match (&body.username, is_guest) { let user_id = match (&body.username, is_guest) {
(Some(username), false) => { (Some(username), false) => {
let proposed_user_id = let proposed_user_id =
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) UserId::parse_with_server_name(username.to_lowercase(), services.globals.server_name())
.ok() .ok()
.filter(|user_id| !user_id.is_historical() && user_is_local(user_id)) .filter(|user_id| !user_id.is_historical() && user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
if services().users.exists(&proposed_user_id)? { if services.users.exists(&proposed_user_id)? {
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
} }
if services() if services
.globals .globals
.forbidden_usernames() .forbidden_usernames()
.is_match(proposed_user_id.localpart()) .is_match(proposed_user_id.localpart())
@ -155,10 +156,10 @@ pub(crate) async fn register_route(
_ => loop { _ => loop {
let proposed_user_id = UserId::parse_with_server_name( let proposed_user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
services().globals.server_name(), services.globals.server_name(),
) )
.unwrap(); .unwrap();
if !services().users.exists(&proposed_user_id)? { if !services.users.exists(&proposed_user_id)? {
break proposed_user_id; break proposed_user_id;
} }
}, },
@ -172,13 +173,13 @@ pub(crate) async fn register_route(
} else { } else {
return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing appservice token.")); return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing appservice token."));
} }
} else if services().appservice.is_exclusive_user_id(&user_id).await { } else if services.appservice.is_exclusive_user_id(&user_id).await {
return Err(Error::BadRequest(ErrorKind::Exclusive, "User ID reserved by appservice.")); return Err(Error::BadRequest(ErrorKind::Exclusive, "User ID reserved by appservice."));
} }
// UIAA // UIAA
let mut uiaainfo; let mut uiaainfo;
let skip_auth = if services().globals.config.registration_token.is_some() { let skip_auth = if services.globals.config.registration_token.is_some() {
// Registration token required // Registration token required
uiaainfo = UiaaInfo { uiaainfo = UiaaInfo {
flows: vec![AuthFlow { flows: vec![AuthFlow {
@ -206,8 +207,8 @@ pub(crate) async fn register_route(
if !skip_auth { if !skip_auth {
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services().uiaa.try_auth( let (worked, uiaainfo) = services.uiaa.try_auth(
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"),
"".into(), "".into(),
auth, auth,
&uiaainfo, &uiaainfo,
@ -218,8 +219,8 @@ pub(crate) async fn register_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services().uiaa.create( services.uiaa.create(
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), &UserId::parse_with_server_name("", services.globals.server_name()).expect("we know this is valid"),
"".into(), "".into(),
&uiaainfo, &uiaainfo,
&json, &json,
@ -237,25 +238,25 @@ pub(crate) async fn register_route(
}; };
// Create user // Create user
services().users.create(&user_id, password)?; services.users.create(&user_id, password)?;
// Default to pretty displayname // Default to pretty displayname
let mut displayname = user_id.localpart().to_owned(); let mut displayname = user_id.localpart().to_owned();
// If `new_user_displayname_suffix` is set, registration will push whatever // If `new_user_displayname_suffix` is set, registration will push whatever
// content is set to the user's display name with a space before it // content is set to the user's display name with a space before it
if !services().globals.new_user_displayname_suffix().is_empty() { if !services.globals.new_user_displayname_suffix().is_empty() {
write!(displayname, " {}", services().globals.config.new_user_displayname_suffix) write!(displayname, " {}", services.globals.config.new_user_displayname_suffix)
.expect("should be able to write to string buffer"); .expect("should be able to write to string buffer");
} }
services() services
.users .users
.set_displayname(&user_id, Some(displayname.clone())) .set_displayname(&user_id, Some(displayname.clone()))
.await?; .await?;
// Initial account data // Initial account data
services().account_data.update( services.account_data.update(
None, None,
&user_id, &user_id,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -290,7 +291,7 @@ pub(crate) async fn register_route(
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account // Create device for this account
services() services
.users .users
.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; .create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
@ -299,7 +300,7 @@ pub(crate) async fn register_route(
// log in conduit admin channel if a non-guest user registered // log in conduit admin channel if a non-guest user registered
if body.appservice_info.is_none() && !is_guest { if body.appservice_info.is_none() && !is_guest {
info!("New user \"{user_id}\" registered on this server."); info!("New user \"{user_id}\" registered on this server.");
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"New user \"{user_id}\" registered on this server from IP {client}." "New user \"{user_id}\" registered on this server from IP {client}."
@ -308,7 +309,7 @@ pub(crate) async fn register_route(
} }
// log in conduit admin channel if a guest registered // log in conduit admin channel if a guest registered
if body.appservice_info.is_none() && is_guest && services().globals.log_guest_registrations() { if body.appservice_info.is_none() && is_guest && services.globals.log_guest_registrations() {
info!("New guest user \"{user_id}\" registered on this server."); info!("New guest user \"{user_id}\" registered on this server.");
if let Some(device_display_name) = &body.initial_device_display_name { if let Some(device_display_name) = &body.initial_device_display_name {
@ -317,7 +318,7 @@ pub(crate) async fn register_route(
.as_ref() .as_ref()
.is_some_and(|device_display_name| !device_display_name.is_empty()) .is_some_and(|device_display_name| !device_display_name.is_empty())
{ {
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"Guest user \"{user_id}\" with device display name `{device_display_name}` registered on this \ "Guest user \"{user_id}\" with device display name `{device_display_name}` registered on this \
@ -325,7 +326,7 @@ pub(crate) async fn register_route(
))) )))
.await; .await;
} else { } else {
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"Guest user \"{user_id}\" with no device display name registered on this server from IP \ "Guest user \"{user_id}\" with no device display name registered on this server from IP \
@ -334,7 +335,7 @@ pub(crate) async fn register_route(
.await; .await;
} }
} else { } else {
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"Guest user \"{user_id}\" with no device display name registered on this server from IP {client}.", "Guest user \"{user_id}\" with no device display name registered on this server from IP {client}.",
@ -347,12 +348,7 @@ pub(crate) async fn register_route(
// users Note: the server user, @conduit:servername, is generated first // users Note: the server user, @conduit:servername, is generated first
if !is_guest { if !is_guest {
if let Some(admin_room) = service::admin::Service::get_admin_room()? { if let Some(admin_room) = service::admin::Service::get_admin_room()? {
if services() if services.rooms.state_cache.room_joined_count(&admin_room)? == Some(1) {
.rooms
.state_cache
.room_joined_count(&admin_room)?
== Some(1)
{
service::admin::make_user_admin(&user_id, displayname).await?; service::admin::make_user_admin(&user_id, displayname).await?;
warn!("Granting {user_id} admin privileges as the first user"); warn!("Granting {user_id} admin privileges as the first user");
@ -361,14 +357,14 @@ pub(crate) async fn register_route(
} }
if body.appservice_info.is_none() if body.appservice_info.is_none()
&& !services().globals.config.auto_join_rooms.is_empty() && !services.globals.config.auto_join_rooms.is_empty()
&& (services().globals.allow_guests_auto_join_rooms() || !is_guest) && (services.globals.allow_guests_auto_join_rooms() || !is_guest)
{ {
for room in &services().globals.config.auto_join_rooms { for room in &services.globals.config.auto_join_rooms {
if !services() if !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services().globals.server_name(), room)? .server_in_room(services.globals.server_name(), room)?
{ {
warn!("Skipping room {room} to automatically join as we have never joined before."); warn!("Skipping room {room} to automatically join as we have never joined before.");
continue; continue;
@ -376,10 +372,11 @@ pub(crate) async fn register_route(
if let Some(room_id_server_name) = room.server_name() { if let Some(room_id_server_name) = room.server_name() {
if let Err(e) = join_room_by_id_helper( if let Err(e) = join_room_by_id_helper(
services,
&user_id, &user_id,
room, room,
Some("Automatically joining this room upon registration".to_owned()), Some("Automatically joining this room upon registration".to_owned()),
&[room_id_server_name.to_owned(), services().globals.server_name().to_owned()], &[room_id_server_name.to_owned(), services.globals.server_name().to_owned()],
None, None,
) )
.await .await
@ -421,7 +418,8 @@ pub(crate) async fn register_route(
/// - Triggers device list updates /// - Triggers device list updates
#[tracing::instrument(skip_all, fields(%client), name = "change_password")] #[tracing::instrument(skip_all, fields(%client), name = "change_password")]
pub(crate) async fn change_password_route( pub(crate) async fn change_password_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<change_password::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<change_password::v3::Request>,
) -> Result<change_password::v3::Response> { ) -> Result<change_password::v3::Response> {
// Authentication for this endpoint was made optional, but we need // Authentication for this endpoint was made optional, but we need
// authentication currently // authentication currently
@ -442,7 +440,7 @@ pub(crate) async fn change_password_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services() let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
@ -451,7 +449,7 @@ pub(crate) async fn change_password_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -459,24 +457,24 @@ pub(crate) async fn change_password_route(
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services() services
.users .users
.set_password(sender_user, Some(&body.new_password))?; .set_password(sender_user, Some(&body.new_password))?;
if body.logout_devices { if body.logout_devices {
// Logout all devices except the current one // Logout all devices except the current one
for id in services() for id in services
.users .users
.all_device_ids(sender_user) .all_device_ids(sender_user)
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|id| id != sender_device) .filter(|id| id != sender_device)
{ {
services().users.remove_device(sender_user, &id)?; services.users.remove_device(sender_user, &id)?;
} }
} }
info!("User {sender_user} changed their password."); info!("User {sender_user} changed their password.");
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} changed their password." "User {sender_user} changed their password."
@ -491,14 +489,16 @@ pub(crate) async fn change_password_route(
/// Get `user_id` of the sender user. /// Get `user_id` of the sender user.
/// ///
/// Note: Also works for Application Services /// Note: Also works for Application Services
pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3::Response> { pub(crate) async fn whoami_route(
State(services): State<crate::State>, body: Ruma<whoami::v3::Request>,
) -> Result<whoami::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device_id = body.sender_device.clone(); let device_id = body.sender_device.clone();
Ok(whoami::v3::Response { Ok(whoami::v3::Response {
user_id: sender_user.clone(), user_id: sender_user.clone(),
device_id, device_id,
is_guest: services().users.is_deactivated(sender_user)? && body.appservice_info.is_none(), is_guest: services.users.is_deactivated(sender_user)? && body.appservice_info.is_none(),
}) })
} }
@ -515,7 +515,8 @@ pub(crate) async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoa
/// - Removes ability to log in again /// - Removes ability to log in again
#[tracing::instrument(skip_all, fields(%client), name = "deactivate")] #[tracing::instrument(skip_all, fields(%client), name = "deactivate")]
pub(crate) async fn deactivate_route( pub(crate) async fn deactivate_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<deactivate::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<deactivate::v3::Request>,
) -> Result<deactivate::v3::Response> { ) -> Result<deactivate::v3::Response> {
// Authentication for this endpoint was made optional, but we need // Authentication for this endpoint was made optional, but we need
// authentication currently // authentication currently
@ -536,7 +537,7 @@ pub(crate) async fn deactivate_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services() let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
@ -545,7 +546,7 @@ pub(crate) async fn deactivate_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -554,23 +555,23 @@ pub(crate) async fn deactivate_route(
} }
// Remove devices and mark account as deactivated // Remove devices and mark account as deactivated
services().users.deactivate_account(sender_user)?; services.users.deactivate_account(sender_user)?;
// Remove profile pictures and display name // Remove profile pictures and display name
let all_joined_rooms: Vec<OwnedRoomId> = services() let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
super::update_displayname(sender_user.clone(), None, all_joined_rooms.clone()).await?; super::update_displayname(services, sender_user.clone(), None, all_joined_rooms.clone()).await?;
super::update_avatar_url(sender_user.clone(), None, None, all_joined_rooms).await?; super::update_avatar_url(services, sender_user.clone(), None, None, all_joined_rooms).await?;
// Make the user leave all rooms before deactivation // Make the user leave all rooms before deactivation
super::leave_all_rooms(sender_user).await; super::leave_all_rooms(services, sender_user).await;
info!("User {sender_user} deactivated their account."); info!("User {sender_user} deactivated their account.");
services() services
.admin .admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {sender_user} deactivated their account." "User {sender_user} deactivated their account."
@ -632,9 +633,9 @@ pub(crate) async fn request_3pid_management_token_via_msisdn_route(
/// Currently does not have any ratelimiting, and this isn't very practical as /// Currently does not have any ratelimiting, and this isn't very practical as
/// there is only one registration token allowed. /// there is only one registration token allowed.
pub(crate) async fn check_registration_token_validity( pub(crate) async fn check_registration_token_validity(
body: Ruma<check_registration_token_validity::v1::Request>, State(services): State<crate::State>, body: Ruma<check_registration_token_validity::v1::Request>,
) -> Result<check_registration_token_validity::v1::Response> { ) -> Result<check_registration_token_validity::v1::Response> {
let Some(reg_token) = services().globals.config.registration_token.clone() else { let Some(reg_token) = services.globals.config.registration_token.clone() else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"Server does not allow token registration.", "Server does not allow token registration.",

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -8,19 +9,24 @@ use ruma::{
}; };
use tracing::debug; use tracing::debug;
use crate::{service::server_is_ours, services, Error, Result, Ruma}; use crate::{
service::{server_is_ours, Services},
Error, Result, Ruma,
};
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
/// ///
/// Creates a new room alias on this server. /// Creates a new room alias on this server.
pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> { pub(crate) async fn create_alias_route(
State(services): State<crate::State>, body: Ruma<create_alias::v3::Request>,
) -> Result<create_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?;
// this isn't apart of alias_checks or delete alias route because we should // this isn't apart of alias_checks or delete alias route because we should
// allow removing forbidden room aliases // allow removing forbidden room aliases
if services() if services
.globals .globals
.forbidden_alias_names() .forbidden_alias_names()
.is_match(body.room_alias.alias()) .is_match(body.room_alias.alias())
@ -28,7 +34,7 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) ->
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Room alias is forbidden."));
} }
if services() if services
.rooms .rooms
.alias .alias
.resolve_local_alias(&body.room_alias)? .resolve_local_alias(&body.room_alias)?
@ -37,7 +43,7 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) ->
return Err(Error::Conflict("Alias already exists.")); return Err(Error::Conflict("Alias already exists."));
} }
services() services
.rooms .rooms
.alias .alias
.set_alias(&body.room_alias, &body.room_id, sender_user)?; .set_alias(&body.room_alias, &body.room_id, sender_user)?;
@ -50,12 +56,14 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) ->
/// Deletes a room alias from this server. /// Deletes a room alias from this server.
/// ///
/// - TODO: Update canonical alias event /// - TODO: Update canonical alias event
pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> { pub(crate) async fn delete_alias_route(
State(services): State<crate::State>, body: Ruma<delete_alias::v3::Request>,
) -> Result<delete_alias::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?; service::rooms::alias::appservice_checks(&body.room_alias, &body.appservice_info).await?;
if services() if services
.rooms .rooms
.alias .alias
.resolve_local_alias(&body.room_alias)? .resolve_local_alias(&body.room_alias)?
@ -64,7 +72,7 @@ pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) ->
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
} }
services() services
.rooms .rooms
.alias .alias
.remove_alias(&body.room_alias, sender_user) .remove_alias(&body.room_alias, sender_user)
@ -78,11 +86,13 @@ pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) ->
/// # `GET /_matrix/client/v3/directory/room/{roomAlias}` /// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
/// ///
/// Resolve an alias locally or over federation. /// Resolve an alias locally or over federation.
pub(crate) async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> { pub(crate) async fn get_alias_route(
State(services): State<crate::State>, body: Ruma<get_alias::v3::Request>,
) -> Result<get_alias::v3::Response> {
let room_alias = body.body.room_alias; let room_alias = body.body.room_alias;
let servers = None; let servers = None;
let Ok((room_id, pre_servers)) = services() let Ok((room_id, pre_servers)) = services
.rooms .rooms
.alias .alias
.resolve_alias(&room_alias, servers.as_ref()) .resolve_alias(&room_alias, servers.as_ref())
@ -91,17 +101,17 @@ pub(crate) async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Resul
return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found."));
}; };
let servers = room_available_servers(&room_id, &room_alias, &pre_servers); let servers = room_available_servers(services, &room_id, &room_alias, &pre_servers);
debug!(?room_alias, ?room_id, "available servers: {servers:?}"); debug!(?room_alias, ?room_id, "available servers: {servers:?}");
Ok(get_alias::v3::Response::new(room_id, servers)) Ok(get_alias::v3::Response::new(room_id, servers))
} }
fn room_available_servers( fn room_available_servers(
room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>, services: &Services, room_id: &RoomId, room_alias: &RoomAliasId, pre_servers: &Option<Vec<OwnedServerName>>,
) -> Vec<OwnedServerName> { ) -> Vec<OwnedServerName> {
// find active servers in room state cache to suggest // find active servers in room state cache to suggest
let mut servers: Vec<OwnedServerName> = services() let mut servers: Vec<OwnedServerName> = services
.rooms .rooms
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
@ -127,7 +137,7 @@ fn room_available_servers(
.position(|server_name| server_is_ours(server_name)) .position(|server_name| server_is_ours(server_name))
{ {
servers.swap_remove(server_index); servers.swap_remove(server_index);
servers.insert(0, services().globals.server_name().to_owned()); servers.insert(0, services.globals.server_name().to_owned());
} else if let Some(alias_server_index) = servers } else if let Some(alias_server_index) = servers
.iter() .iter()
.position(|server| server == room_alias.server_name()) .position(|server| server == room_alias.server_name())

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{ api::client::{
backup::{ backup::{
@ -11,16 +12,16 @@ use ruma::{
UInt, UInt,
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `POST /_matrix/client/r0/room_keys/version` /// # `POST /_matrix/client/r0/room_keys/version`
/// ///
/// Creates a new backup. /// Creates a new backup.
pub(crate) async fn create_backup_version_route( pub(crate) async fn create_backup_version_route(
body: Ruma<create_backup_version::v3::Request>, State(services): State<crate::State>, body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> { ) -> Result<create_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = services() let version = services
.key_backups .key_backups
.create_backup(sender_user, &body.algorithm)?; .create_backup(sender_user, &body.algorithm)?;
@ -34,10 +35,10 @@ pub(crate) async fn create_backup_version_route(
/// Update information about an existing backup. Only `auth_data` can be /// Update information about an existing backup. Only `auth_data` can be
/// modified. /// modified.
pub(crate) async fn update_backup_version_route( pub(crate) async fn update_backup_version_route(
body: Ruma<update_backup_version::v3::Request>, State(services): State<crate::State>, body: Ruma<update_backup_version::v3::Request>,
) -> Result<update_backup_version::v3::Response> { ) -> Result<update_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services
.key_backups .key_backups
.update_backup(sender_user, &body.version, &body.algorithm)?; .update_backup(sender_user, &body.version, &body.algorithm)?;
@ -48,20 +49,20 @@ pub(crate) async fn update_backup_version_route(
/// ///
/// Get information about the latest backup version. /// Get information about the latest backup version.
pub(crate) async fn get_latest_backup_info_route( pub(crate) async fn get_latest_backup_info_route(
body: Ruma<get_latest_backup_info::v3::Request>, State(services): State<crate::State>, body: Ruma<get_latest_backup_info::v3::Request>,
) -> Result<get_latest_backup_info::v3::Response> { ) -> Result<get_latest_backup_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let (version, algorithm) = services() let (version, algorithm) = services
.key_backups .key_backups
.get_latest_backup(sender_user)? .get_latest_backup(sender_user)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
Ok(get_latest_backup_info::v3::Response { Ok(get_latest_backup_info::v3::Response {
algorithm, algorithm,
count: (UInt::try_from(services().key_backups.count_keys(sender_user, &version)?) count: (UInt::try_from(services.key_backups.count_keys(sender_user, &version)?)
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services().key_backups.get_etag(sender_user, &version)?, etag: services.key_backups.get_etag(sender_user, &version)?,
version, version,
}) })
} }
@ -70,10 +71,10 @@ pub(crate) async fn get_latest_backup_info_route(
/// ///
/// Get information about an existing backup. /// Get information about an existing backup.
pub(crate) async fn get_backup_info_route( pub(crate) async fn get_backup_info_route(
body: Ruma<get_backup_info::v3::Request>, State(services): State<crate::State>, body: Ruma<get_backup_info::v3::Request>,
) -> Result<get_backup_info::v3::Response> { ) -> Result<get_backup_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let algorithm = services() let algorithm = services
.key_backups .key_backups
.get_backup(sender_user, &body.version)? .get_backup(sender_user, &body.version)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
@ -81,14 +82,12 @@ pub(crate) async fn get_backup_info_route(
Ok(get_backup_info::v3::Response { Ok(get_backup_info::v3::Response {
algorithm, algorithm,
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
version: body.version.clone(), version: body.version.clone(),
}) })
} }
@ -100,11 +99,11 @@ pub(crate) async fn get_backup_info_route(
/// - Deletes both information about the backup, as well as all key data related /// - Deletes both information about the backup, as well as all key data related
/// to the backup /// to the backup
pub(crate) async fn delete_backup_version_route( pub(crate) async fn delete_backup_version_route(
body: Ruma<delete_backup_version::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_backup_version::v3::Request>,
) -> Result<delete_backup_version::v3::Response> { ) -> Result<delete_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services
.key_backups .key_backups
.delete_backup(sender_user, &body.version)?; .delete_backup(sender_user, &body.version)?;
@ -120,12 +119,12 @@ pub(crate) async fn delete_backup_version_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_route( pub(crate) async fn add_backup_keys_route(
body: Ruma<add_backup_keys::v3::Request>, State(services): State<crate::State>, body: Ruma<add_backup_keys::v3::Request>,
) -> Result<add_backup_keys::v3::Response> { ) -> Result<add_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= services() != services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -138,7 +137,7 @@ pub(crate) async fn add_backup_keys_route(
for (room_id, room) in &body.rooms { for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions { for (session_id, key_data) in &room.sessions {
services() services
.key_backups .key_backups
.add_key(sender_user, &body.version, room_id, session_id, key_data)?; .add_key(sender_user, &body.version, room_id, session_id, key_data)?;
} }
@ -146,14 +145,12 @@ pub(crate) async fn add_backup_keys_route(
Ok(add_backup_keys::v3::Response { Ok(add_backup_keys::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -166,12 +163,12 @@ pub(crate) async fn add_backup_keys_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_room_route( pub(crate) async fn add_backup_keys_for_room_route(
body: Ruma<add_backup_keys_for_room::v3::Request>, State(services): State<crate::State>, body: Ruma<add_backup_keys_for_room::v3::Request>,
) -> Result<add_backup_keys_for_room::v3::Response> { ) -> Result<add_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= services() != services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -183,21 +180,19 @@ pub(crate) async fn add_backup_keys_for_room_route(
} }
for (session_id, key_data) in &body.sessions { for (session_id, key_data) in &body.sessions {
services() services
.key_backups .key_backups
.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; .add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
} }
Ok(add_backup_keys_for_room::v3::Response { Ok(add_backup_keys_for_room::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -210,12 +205,12 @@ pub(crate) async fn add_backup_keys_for_room_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub(crate) async fn add_backup_keys_for_session_route( pub(crate) async fn add_backup_keys_for_session_route(
body: Ruma<add_backup_keys_for_session::v3::Request>, State(services): State<crate::State>, body: Ruma<add_backup_keys_for_session::v3::Request>,
) -> Result<add_backup_keys_for_session::v3::Response> { ) -> Result<add_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= services() != services
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -226,20 +221,18 @@ pub(crate) async fn add_backup_keys_for_session_route(
)); ));
} }
services() services
.key_backups .key_backups
.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; .add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
Ok(add_backup_keys_for_session::v3::Response { Ok(add_backup_keys_for_session::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -247,11 +240,11 @@ pub(crate) async fn add_backup_keys_for_session_route(
/// ///
/// Retrieves all keys from the backup. /// Retrieves all keys from the backup.
pub(crate) async fn get_backup_keys_route( pub(crate) async fn get_backup_keys_route(
body: Ruma<get_backup_keys::v3::Request>, State(services): State<crate::State>, body: Ruma<get_backup_keys::v3::Request>,
) -> Result<get_backup_keys::v3::Response> { ) -> Result<get_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let rooms = services().key_backups.get_all(sender_user, &body.version)?; let rooms = services.key_backups.get_all(sender_user, &body.version)?;
Ok(get_backup_keys::v3::Response { Ok(get_backup_keys::v3::Response {
rooms, rooms,
@ -262,11 +255,11 @@ pub(crate) async fn get_backup_keys_route(
/// ///
/// Retrieves all keys from the backup for a given room. /// Retrieves all keys from the backup for a given room.
pub(crate) async fn get_backup_keys_for_room_route( pub(crate) async fn get_backup_keys_for_room_route(
body: Ruma<get_backup_keys_for_room::v3::Request>, State(services): State<crate::State>, body: Ruma<get_backup_keys_for_room::v3::Request>,
) -> Result<get_backup_keys_for_room::v3::Response> { ) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = services() let sessions = services
.key_backups .key_backups
.get_room(sender_user, &body.version, &body.room_id)?; .get_room(sender_user, &body.version, &body.room_id)?;
@ -279,11 +272,11 @@ pub(crate) async fn get_backup_keys_for_room_route(
/// ///
/// Retrieves a key from the backup. /// Retrieves a key from the backup.
pub(crate) async fn get_backup_keys_for_session_route( pub(crate) async fn get_backup_keys_for_session_route(
body: Ruma<get_backup_keys_for_session::v3::Request>, State(services): State<crate::State>, body: Ruma<get_backup_keys_for_session::v3::Request>,
) -> Result<get_backup_keys_for_session::v3::Response> { ) -> Result<get_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let key_data = services() let key_data = services
.key_backups .key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."))?;
@ -297,24 +290,22 @@ pub(crate) async fn get_backup_keys_for_session_route(
/// ///
/// Delete the keys from the backup. /// Delete the keys from the backup.
pub(crate) async fn delete_backup_keys_route( pub(crate) async fn delete_backup_keys_route(
body: Ruma<delete_backup_keys::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_backup_keys::v3::Request>,
) -> Result<delete_backup_keys::v3::Response> { ) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services
.key_backups .key_backups
.delete_all_keys(sender_user, &body.version)?; .delete_all_keys(sender_user, &body.version)?;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -322,24 +313,22 @@ pub(crate) async fn delete_backup_keys_route(
/// ///
/// Delete the keys from the backup for a given room. /// Delete the keys from the backup for a given room.
pub(crate) async fn delete_backup_keys_for_room_route( pub(crate) async fn delete_backup_keys_for_room_route(
body: Ruma<delete_backup_keys_for_room::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_room::v3::Request>,
) -> Result<delete_backup_keys_for_room::v3::Response> { ) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services
.key_backups .key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?; .delete_room_keys(sender_user, &body.version, &body.room_id)?;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }
@ -347,23 +336,21 @@ pub(crate) async fn delete_backup_keys_for_room_route(
/// ///
/// Delete a key from the backup. /// Delete a key from the backup.
pub(crate) async fn delete_backup_keys_for_session_route( pub(crate) async fn delete_backup_keys_for_session_route(
body: Ruma<delete_backup_keys_for_session::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_backup_keys_for_session::v3::Request>,
) -> Result<delete_backup_keys_for_session::v3::Response> { ) -> Result<delete_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services
.key_backups .key_backups
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
Ok(delete_backup_keys_for_session::v3::Response { Ok(delete_backup_keys_for_session::v3::Response {
count: (UInt::try_from( count: (UInt::try_from(
services() services
.key_backups .key_backups
.count_keys(sender_user, &body.version)?, .count_keys(sender_user, &body.version)?,
) )
.expect("user backup keys count should not be that high")), .expect("user backup keys count should not be that high")),
etag: services() etag: services.key_backups.get_etag(sender_user, &body.version)?,
.key_backups
.get_etag(sender_user, &body.version)?,
}) })
} }

View file

@ -1,29 +1,30 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use ruma::api::client::discovery::get_capabilities::{ use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability, ThirdPartyIdChangesCapability, self, Capabilities, RoomVersionStability, RoomVersionsCapability, ThirdPartyIdChangesCapability,
}; };
use crate::{services, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/client/v3/capabilities` /// # `GET /_matrix/client/v3/capabilities`
/// ///
/// Get information on the supported feature set and other relevent capabilities /// Get information on the supported feature set and other relevent capabilities
/// of this server. /// of this server.
pub(crate) async fn get_capabilities_route( pub(crate) async fn get_capabilities_route(
_body: Ruma<get_capabilities::v3::Request>, State(services): State<crate::State>, _body: Ruma<get_capabilities::v3::Request>,
) -> Result<get_capabilities::v3::Response> { ) -> Result<get_capabilities::v3::Response> {
let mut available = BTreeMap::new(); let mut available = BTreeMap::new();
for room_version in &services().globals.unstable_room_versions { for room_version in &services.globals.unstable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Unstable); available.insert(room_version.clone(), RoomVersionStability::Unstable);
} }
for room_version in &services().globals.stable_room_versions { for room_version in &services.globals.stable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Stable); available.insert(room_version.clone(), RoomVersionStability::Stable);
} }
let mut capabilities = Capabilities::default(); let mut capabilities = Capabilities::default();
capabilities.room_versions = RoomVersionsCapability { capabilities.room_versions = RoomVersionsCapability {
default: services().globals.default_room_version(), default: services.globals.default_room_version(),
available, available,
}; };

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{ api::client::{
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
@ -10,15 +11,21 @@ use ruma::{
use serde::Deserialize; use serde::Deserialize;
use serde_json::{json, value::RawValue as RawJsonValue}; use serde_json::{json, value::RawValue as RawJsonValue};
use crate::{services, Error, Result, Ruma}; use crate::{service::Services, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
/// ///
/// Sets some account data for the sender user. /// Sets some account data for the sender user.
pub(crate) async fn set_global_account_data_route( pub(crate) async fn set_global_account_data_route(
body: Ruma<set_global_account_data::v3::Request>, State(services): State<crate::State>, body: Ruma<set_global_account_data::v3::Request>,
) -> Result<set_global_account_data::v3::Response> { ) -> Result<set_global_account_data::v3::Response> {
set_account_data(None, &body.sender_user, &body.event_type.to_string(), body.data.json())?; set_account_data(
services,
None,
&body.sender_user,
&body.event_type.to_string(),
body.data.json(),
)?;
Ok(set_global_account_data::v3::Response {}) Ok(set_global_account_data::v3::Response {})
} }
@ -27,9 +34,10 @@ pub(crate) async fn set_global_account_data_route(
/// ///
/// Sets some room account data for the sender user. /// Sets some room account data for the sender user.
pub(crate) async fn set_room_account_data_route( pub(crate) async fn set_room_account_data_route(
body: Ruma<set_room_account_data::v3::Request>, State(services): State<crate::State>, body: Ruma<set_room_account_data::v3::Request>,
) -> Result<set_room_account_data::v3::Response> { ) -> Result<set_room_account_data::v3::Response> {
set_account_data( set_account_data(
services,
Some(&body.room_id), Some(&body.room_id),
&body.sender_user, &body.sender_user,
&body.event_type.to_string(), &body.event_type.to_string(),
@ -43,11 +51,11 @@ pub(crate) async fn set_room_account_data_route(
/// ///
/// Gets some account data for the sender user. /// Gets some account data for the sender user.
pub(crate) async fn get_global_account_data_route( pub(crate) async fn get_global_account_data_route(
body: Ruma<get_global_account_data::v3::Request>, State(services): State<crate::State>, body: Ruma<get_global_account_data::v3::Request>,
) -> Result<get_global_account_data::v3::Response> { ) -> Result<get_global_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = services() let event: Box<RawJsonValue> = services
.account_data .account_data
.get(None, sender_user, body.event_type.to_string().into())? .get(None, sender_user, body.event_type.to_string().into())?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
@ -65,11 +73,11 @@ pub(crate) async fn get_global_account_data_route(
/// ///
/// Gets some room account data for the sender user. /// Gets some room account data for the sender user.
pub(crate) async fn get_room_account_data_route( pub(crate) async fn get_room_account_data_route(
body: Ruma<get_room_account_data::v3::Request>, State(services): State<crate::State>, body: Ruma<get_room_account_data::v3::Request>,
) -> Result<get_room_account_data::v3::Response> { ) -> Result<get_room_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = services() let event: Box<RawJsonValue> = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, body.event_type.clone())? .get(Some(&body.room_id), sender_user, body.event_type.clone())?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
@ -84,14 +92,15 @@ pub(crate) async fn get_room_account_data_route(
} }
fn set_account_data( fn set_account_data(
room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str, data: &RawJsonValue, services: &Services, room_id: Option<&RoomId>, sender_user: &Option<OwnedUserId>, event_type: &str,
data: &RawJsonValue,
) -> Result<()> { ) -> Result<()> {
let sender_user = sender_user.as_ref().expect("user is authenticated"); let sender_user = sender_user.as_ref().expect("user is authenticated");
let data: serde_json::Value = let data: serde_json::Value =
serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; serde_json::from_str(data.get()).map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?;
services().account_data.update( services.account_data.update(
room_id, room_id,
sender_user, sender_user,
event_type.into(), event_type.into(),

View file

@ -1,12 +1,13 @@
use std::collections::HashSet; use std::collections::HashSet;
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
events::StateEventType, events::StateEventType,
}; };
use tracing::error; use tracing::error;
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// # `GET /_matrix/client/r0/rooms/{roomId}/context`
/// ///
@ -14,7 +15,9 @@ use crate::{services, Error, Result, Ruma};
/// ///
/// - Only works if the user is joined (TODO: always allow, but only show events /// - Only works if the user is joined (TODO: always allow, but only show events
/// if the user was joined, depending on history_visibility) /// if the user was joined, depending on history_visibility)
pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<get_context::v3::Response> { pub(crate) async fn get_context_route(
State(services): State<crate::State>, body: Ruma<get_context::v3::Request>,
) -> Result<get_context::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
@ -27,13 +30,13 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let base_token = services() let base_token = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
let base_event = services() let base_event = services
.rooms .rooms
.timeline .timeline
.get_pdu(&body.event_id)? .get_pdu(&body.event_id)?
@ -41,7 +44,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
let room_id = base_event.room_id.clone(); let room_id = base_event.room_id.clone();
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &body.event_id)? .user_can_see_event(sender_user, &room_id, &body.event_id)?
@ -52,7 +55,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
)); ));
} }
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -67,14 +70,14 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
let base_event = base_event.to_room_event(); let base_event = base_event.to_room_event();
let events_before: Vec<_> = services() let events_before: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &room_id, base_token)? .pdus_until(sender_user, &room_id, base_token)?
.take(limit / 2) .take(limit / 2)
.filter_map(Result::ok) // Remove buggy events .filter_map(Result::ok) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, &room_id, &pdu.event_id)
@ -83,7 +86,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
.collect(); .collect();
for (_, event) in &events_before { for (_, event) in &events_before {
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -103,14 +106,14 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
.map(|(_, pdu)| pdu.to_room_event()) .map(|(_, pdu)| pdu.to_room_event())
.collect(); .collect();
let events_after: Vec<_> = services() let events_after: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &room_id, base_token)? .pdus_after(sender_user, &room_id, base_token)?
.take(limit / 2) .take(limit / 2)
.filter_map(Result::ok) // Remove buggy events .filter_map(Result::ok) // Remove buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &room_id, &pdu.event_id) .user_can_see_event(sender_user, &room_id, &pdu.event_id)
@ -119,7 +122,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
.collect(); .collect();
for (_, event) in &events_after { for (_, event) in &events_after {
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -130,7 +133,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
} }
} }
let shortstatehash = services() let shortstatehash = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash( .pdu_shortstatehash(
@ -139,7 +142,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
.map_or(&*body.event_id, |(_, e)| &*e.event_id), .map_or(&*body.event_id, |(_, e)| &*e.event_id),
)? )?
.map_or( .map_or(
services() services
.rooms .rooms
.state .state
.get_room_shortstatehash(&room_id)? .get_room_shortstatehash(&room_id)?
@ -147,7 +150,7 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
|hash| hash, |hash| hash,
); );
let state_ids = services() let state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
@ -165,20 +168,20 @@ pub(crate) async fn get_context_route(body: Ruma<get_context::v3::Request>) -> R
let mut state = Vec::with_capacity(state_ids.len()); let mut state = Vec::with_capacity(state_ids.len());
for (shortstatekey, id) in state_ids { for (shortstatekey, id) in state_ids {
let (event_type, state_key) = services() let (event_type, state_key) = services
.rooms .rooms
.short .short
.get_statekey_from_short(shortstatekey)?; .get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
state.push(pdu.to_state_event()); state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::api::client::{ use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind, error::ErrorKind,
@ -5,15 +6,17 @@ use ruma::api::client::{
}; };
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma};
/// # `GET /_matrix/client/r0/devices` /// # `GET /_matrix/client/r0/devices`
/// ///
/// Get metadata on all devices of the sender user. /// Get metadata on all devices of the sender user.
pub(crate) async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> { pub(crate) async fn get_devices_route(
State(services): State<crate::State>, body: Ruma<get_devices::v3::Request>,
) -> Result<get_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let devices: Vec<device::Device> = services() let devices: Vec<device::Device> = services
.users .users
.all_devices_metadata(sender_user) .all_devices_metadata(sender_user)
.filter_map(Result::ok) // Filter out buggy devices .filter_map(Result::ok) // Filter out buggy devices
@ -27,10 +30,12 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> R
/// # `GET /_matrix/client/r0/devices/{deviceId}` /// # `GET /_matrix/client/r0/devices/{deviceId}`
/// ///
/// Get metadata on a single device of the sender user. /// Get metadata on a single device of the sender user.
pub(crate) async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> { pub(crate) async fn get_device_route(
State(services): State<crate::State>, body: Ruma<get_device::v3::Request>,
) -> Result<get_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device = services() let device = services
.users .users
.get_device_metadata(sender_user, &body.body.device_id)? .get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
@ -43,17 +48,19 @@ pub(crate) async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Res
/// # `PUT /_matrix/client/r0/devices/{deviceId}` /// # `PUT /_matrix/client/r0/devices/{deviceId}`
/// ///
/// Updates the metadata on a given device of the sender user. /// Updates the metadata on a given device of the sender user.
pub(crate) async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> { pub(crate) async fn update_device_route(
State(services): State<crate::State>, body: Ruma<update_device::v3::Request>,
) -> Result<update_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut device = services() let mut device = services
.users .users
.get_device_metadata(sender_user, &body.device_id)? .get_device_metadata(sender_user, &body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
device.display_name.clone_from(&body.display_name); device.display_name.clone_from(&body.display_name);
services() services
.users .users
.update_device_metadata(sender_user, &body.device_id, &device)?; .update_device_metadata(sender_user, &body.device_id, &device)?;
@ -70,7 +77,9 @@ pub(crate) async fn update_device_route(body: Ruma<update_device::v3::Request>)
/// last seen ts) /// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> { pub(crate) async fn delete_device_route(
State(services): State<crate::State>, body: Ruma<delete_device::v3::Request>,
) -> Result<delete_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
@ -86,7 +95,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>)
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services() let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
@ -95,7 +104,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>)
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -103,9 +112,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>)
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
services() services.users.remove_device(sender_user, &body.device_id)?;
.users
.remove_device(sender_user, &body.device_id)?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -123,7 +130,7 @@ pub(crate) async fn delete_device_route(body: Ruma<delete_device::v3::Request>)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn delete_devices_route( pub(crate) async fn delete_devices_route(
body: Ruma<delete_devices::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_devices::v3::Request>,
) -> Result<delete_devices::v3::Response> { ) -> Result<delete_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
@ -140,7 +147,7 @@ pub(crate) async fn delete_devices_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services() let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
@ -149,7 +156,7 @@ pub(crate) async fn delete_devices_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -158,7 +165,7 @@ pub(crate) async fn delete_devices_route(
} }
for device_id in &body.devices { for device_id in &body.devices {
services().users.remove_device(sender_user, device_id)?; services.users.remove_device(sender_user, device_id)?;
} }
Ok(delete_devices::v3::Response {}) Ok(delete_devices::v3::Response {})

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{err, info, warn, Error, Result}; use conduit::{err, info, warn, Error, Result};
use ruma::{ use ruma::{
@ -20,7 +21,10 @@ use ruma::{
uint, RoomId, ServerName, UInt, UserId, uint, RoomId, ServerName, UInt, UserId,
}; };
use crate::{service::server_is_ours, services, Ruma}; use crate::{
service::{server_is_ours, Services},
Ruma,
};
/// # `POST /_matrix/client/v3/publicRooms` /// # `POST /_matrix/client/v3/publicRooms`
/// ///
@ -29,10 +33,11 @@ use crate::{service::server_is_ours, services, Ruma};
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
#[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] #[tracing::instrument(skip_all, fields(%client), name = "publicrooms")]
pub(crate) async fn get_public_rooms_filtered_route( pub(crate) async fn get_public_rooms_filtered_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms_filtered::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_public_rooms_filtered::v3::Request>,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(server) = &body.server { if let Some(server) = &body.server {
if services() if services
.globals .globals
.forbidden_remote_room_directory_server_names() .forbidden_remote_room_directory_server_names()
.contains(server) .contains(server)
@ -45,6 +50,7 @@ pub(crate) async fn get_public_rooms_filtered_route(
} }
let response = get_public_rooms_filtered_helper( let response = get_public_rooms_filtered_helper(
services,
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
@ -67,10 +73,11 @@ pub(crate) async fn get_public_rooms_filtered_route(
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
#[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] #[tracing::instrument(skip_all, fields(%client), name = "publicrooms")]
pub(crate) async fn get_public_rooms_route( pub(crate) async fn get_public_rooms_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_public_rooms::v3::Request>,
) -> Result<get_public_rooms::v3::Response> { ) -> Result<get_public_rooms::v3::Response> {
if let Some(server) = &body.server { if let Some(server) = &body.server {
if services() if services
.globals .globals
.forbidden_remote_room_directory_server_names() .forbidden_remote_room_directory_server_names()
.contains(server) .contains(server)
@ -83,6 +90,7 @@ pub(crate) async fn get_public_rooms_route(
} }
let response = get_public_rooms_filtered_helper( let response = get_public_rooms_filtered_helper(
services,
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
@ -108,16 +116,17 @@ pub(crate) async fn get_public_rooms_route(
/// Sets the visibility of a given room in the room directory. /// Sets the visibility of a given room in the room directory.
#[tracing::instrument(skip_all, fields(%client), name = "room_directory")] #[tracing::instrument(skip_all, fields(%client), name = "room_directory")]
pub(crate) async fn set_room_visibility_route( pub(crate) async fn set_room_visibility_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<set_room_visibility::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<set_room_visibility::v3::Request>,
) -> Result<set_room_visibility::v3::Response> { ) -> Result<set_room_visibility::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id)? {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
if !user_can_publish_room(sender_user, &body.room_id)? { if !user_can_publish_room(services, sender_user, &body.room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"User is not allowed to publish this room", "User is not allowed to publish this room",
@ -126,7 +135,7 @@ pub(crate) async fn set_room_visibility_route(
match &body.visibility { match &body.visibility {
room::Visibility::Public => { room::Visibility::Public => {
if services().globals.config.lockdown_public_room_directory && !services().users.is_admin(sender_user)? { if services.globals.config.lockdown_public_room_directory && !services.users.is_admin(sender_user)? {
info!( info!(
"Non-admin user {sender_user} tried to publish {0} to the room directory while \ "Non-admin user {sender_user} tried to publish {0} to the room directory while \
\"lockdown_public_room_directory\" is enabled", \"lockdown_public_room_directory\" is enabled",
@ -139,10 +148,10 @@ pub(crate) async fn set_room_visibility_route(
)); ));
} }
services().rooms.directory.set_public(&body.room_id)?; services.rooms.directory.set_public(&body.room_id)?;
info!("{sender_user} made {0} public", body.room_id); info!("{sender_user} made {0} public", body.room_id);
}, },
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, room::Visibility::Private => services.rooms.directory.set_not_public(&body.room_id)?,
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -158,15 +167,15 @@ pub(crate) async fn set_room_visibility_route(
/// ///
/// Gets the visibility of a given room in the room directory. /// Gets the visibility of a given room in the room directory.
pub(crate) async fn get_room_visibility_route( pub(crate) async fn get_room_visibility_route(
body: Ruma<get_room_visibility::v3::Request>, State(services): State<crate::State>, body: Ruma<get_room_visibility::v3::Request>,
) -> Result<get_room_visibility::v3::Response> { ) -> Result<get_room_visibility::v3::Response> {
if !services().rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id)? {
// Return 404 if the room doesn't exist // Return 404 if the room doesn't exist
return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found"));
} }
Ok(get_room_visibility::v3::Response { Ok(get_room_visibility::v3::Response {
visibility: if services().rooms.directory.is_public_room(&body.room_id)? { visibility: if services.rooms.directory.is_public_room(&body.room_id)? {
room::Visibility::Public room::Visibility::Public
} else { } else {
room::Visibility::Private room::Visibility::Private
@ -175,10 +184,11 @@ pub(crate) async fn get_room_visibility_route(
} }
pub(crate) async fn get_public_rooms_filtered_helper( pub(crate) async fn get_public_rooms_filtered_helper(
server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork, services: &Services, server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter,
_network: &RoomNetwork,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(other_server) = server.filter(|server_name| !server_is_ours(server_name)) { if let Some(other_server) = server.filter(|server_name| !server_is_ours(server_name)) {
let response = services() let response = services
.sending .sending
.send_federation_request( .send_federation_request(
other_server, other_server,
@ -224,7 +234,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
} }
} }
let mut all_rooms: Vec<_> = services() let mut all_rooms: Vec<_> = services
.rooms .rooms
.directory .directory
.public_rooms() .public_rooms()
@ -232,12 +242,12 @@ pub(crate) async fn get_public_rooms_filtered_helper(
let room_id = room_id?; let room_id = room_id?;
let chunk = PublicRoomsChunk { let chunk = PublicRoomsChunk {
canonical_alias: services() canonical_alias: services
.rooms .rooms
.state_accessor .state_accessor
.get_canonical_alias(&room_id)?, .get_canonical_alias(&room_id)?,
name: services().rooms.state_accessor.get_name(&room_id)?, name: services.rooms.state_accessor.get_name(&room_id)?,
num_joined_members: services() num_joined_members: services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(&room_id)? .room_joined_count(&room_id)?
@ -247,24 +257,24 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}) })
.try_into() .try_into()
.expect("user count should not be that big"), .expect("user count should not be that big"),
topic: services() topic: services
.rooms .rooms
.state_accessor .state_accessor
.get_room_topic(&room_id) .get_room_topic(&room_id)
.unwrap_or(None), .unwrap_or(None),
world_readable: services().rooms.state_accessor.is_world_readable(&room_id)?, world_readable: services.rooms.state_accessor.is_world_readable(&room_id)?,
guest_can_join: services() guest_can_join: services
.rooms .rooms
.state_accessor .state_accessor
.guest_can_join(&room_id)?, .guest_can_join(&room_id)?,
avatar_url: services() avatar_url: services
.rooms .rooms
.state_accessor .state_accessor
.get_avatar(&room_id)? .get_avatar(&room_id)?
.into_option() .into_option()
.unwrap_or_default() .unwrap_or_default()
.url, .url,
join_rule: services() join_rule: services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
@ -282,7 +292,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
.transpose()? .transpose()?
.flatten() .flatten()
.ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?,
room_type: services() room_type: services
.rooms .rooms
.state_accessor .state_accessor
.get_room_type(&room_id)?, .get_room_type(&room_id)?,
@ -361,9 +371,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
/// Check whether the user can publish to the room directory via power levels of /// Check whether the user can publish to the room directory via power levels of
/// room history visibility event or room creator /// room history visibility event or room creator
fn user_can_publish_room(user_id: &UserId, room_id: &RoomId) -> Result<bool> { fn user_can_publish_room(services: &Services, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
if let Some(event) = if let Some(event) = services
services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
@ -374,7 +383,7 @@ fn user_can_publish_room(user_id: &UserId, room_id: &RoomId) -> Result<bool> {
RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility) RoomPowerLevels::from(content).user_can_send_state(user_id, StateEventType::RoomHistoryVisibility)
}) })
} else if let Some(event) = } else if let Some(event) =
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")?

View file

@ -1,18 +1,21 @@
use axum::extract::State;
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
filter::{create_filter, get_filter}, filter::{create_filter, get_filter},
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
/// ///
/// Loads a filter that was previously created. /// Loads a filter that was previously created.
/// ///
/// - A user can only access their own filters /// - A user can only access their own filters
pub(crate) async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> { pub(crate) async fn get_filter_route(
State(services): State<crate::State>, body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let Some(filter) = services().users.get_filter(sender_user, &body.filter_id)? else { let Some(filter) = services.users.get_filter(sender_user, &body.filter_id)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found."));
}; };
@ -22,9 +25,11 @@ pub(crate) async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Res
/// # `PUT /_matrix/client/r0/user/{userId}/filter` /// # `PUT /_matrix/client/r0/user/{userId}/filter`
/// ///
/// Creates a new filter to be used by other endpoints. /// Creates a new filter to be used by other endpoints.
pub(crate) async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> { pub(crate) async fn create_filter_route(
State(services): State<crate::State>, body: Ruma<create_filter::v3::Request>,
) -> Result<create_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::v3::Response::new( Ok(create_filter::v3::Response::new(
services().users.create_filter(sender_user, &body.filter)?, services.users.create_filter(sender_user, &body.filter)?,
)) ))
} }

View file

@ -3,6 +3,7 @@ use std::{
time::Instant, time::Instant,
}; };
use axum::extract::State;
use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result}; use conduit::{utils, utils::math::continue_exponential_backoff_secs, Error, Result};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
@ -22,7 +23,7 @@ use service::user_is_local;
use tracing::debug; use tracing::debug;
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{services, Ruma}; use crate::{service::Services, Ruma};
/// # `POST /_matrix/client/r0/keys/upload` /// # `POST /_matrix/client/r0/keys/upload`
/// ///
@ -31,12 +32,14 @@ use crate::{services, Ruma};
/// - Adds one time keys /// - Adds one time keys
/// - If there are no device keys yet: Adds device keys (TODO: merge with /// - If there are no device keys yet: Adds device keys (TODO: merge with
/// existing keys?) /// existing keys?)
pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> { pub(crate) async fn upload_keys_route(
State(services): State<crate::State>, body: Ruma<upload_keys::v3::Request>,
) -> Result<upload_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
services() services
.users .users
.add_one_time_key(sender_user, sender_device, key_key, key_value)?; .add_one_time_key(sender_user, sender_device, key_key, key_value)?;
} }
@ -44,19 +47,19 @@ pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> R
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
// TODO: merge this and the existing event? // TODO: merge this and the existing event?
// This check is needed to assure that signatures are kept // This check is needed to assure that signatures are kept
if services() if services
.users .users
.get_device_keys(sender_user, sender_device)? .get_device_keys(sender_user, sender_device)?
.is_none() .is_none()
{ {
services() services
.users .users
.add_device_keys(sender_user, sender_device, device_keys)?; .add_device_keys(sender_user, sender_device, device_keys)?;
} }
} }
Ok(upload_keys::v3::Response { Ok(upload_keys::v3::Response {
one_time_key_counts: services() one_time_key_counts: services
.users .users
.count_one_time_keys(sender_user, sender_device)?, .count_one_time_keys(sender_user, sender_device)?,
}) })
@ -70,10 +73,13 @@ pub(crate) async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> R
/// - Gets master keys, self-signing keys, user signing keys and device keys. /// - Gets master keys, self-signing keys, user signing keys and device keys.
/// - The master and self-signing keys contain signatures that the user is /// - The master and self-signing keys contain signatures that the user is
/// allowed to see /// allowed to see
pub(crate) async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> { pub(crate) async fn get_keys_route(
State(services): State<crate::State>, body: Ruma<get_keys::v3::Request>,
) -> Result<get_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
get_keys_helper( get_keys_helper(
services,
Some(sender_user), Some(sender_user),
&body.device_keys, &body.device_keys,
|u| u == sender_user, |u| u == sender_user,
@ -85,8 +91,10 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<
/// # `POST /_matrix/client/r0/keys/claim` /// # `POST /_matrix/client/r0/keys/claim`
/// ///
/// Claims one-time keys /// Claims one-time keys
pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> { pub(crate) async fn claim_keys_route(
claim_keys_helper(&body.one_time_keys).await State(services): State<crate::State>, body: Ruma<claim_keys::v3::Request>,
) -> Result<claim_keys::v3::Response> {
claim_keys_helper(services, &body.one_time_keys).await
} }
/// # `POST /_matrix/client/r0/keys/device_signing/upload` /// # `POST /_matrix/client/r0/keys/device_signing/upload`
@ -95,7 +103,7 @@ pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Res
/// ///
/// - Requires UIAA to verify password /// - Requires UIAA to verify password
pub(crate) async fn upload_signing_keys_route( pub(crate) async fn upload_signing_keys_route(
body: Ruma<upload_signing_keys::v3::Request>, State(services): State<crate::State>, body: Ruma<upload_signing_keys::v3::Request>,
) -> Result<upload_signing_keys::v3::Response> { ) -> Result<upload_signing_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
@ -112,7 +120,7 @@ pub(crate) async fn upload_signing_keys_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = services() let (worked, uiaainfo) = services
.uiaa .uiaa
.try_auth(sender_user, sender_device, auth, &uiaainfo)?; .try_auth(sender_user, sender_device, auth, &uiaainfo)?;
if !worked { if !worked {
@ -121,7 +129,7 @@ pub(crate) async fn upload_signing_keys_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
services() services
.uiaa .uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -130,7 +138,7 @@ pub(crate) async fn upload_signing_keys_route(
} }
if let Some(master_key) = &body.master_key { if let Some(master_key) = &body.master_key {
services().users.add_cross_signing_keys( services.users.add_cross_signing_keys(
sender_user, sender_user,
master_key, master_key,
&body.self_signing_key, &body.self_signing_key,
@ -146,7 +154,7 @@ pub(crate) async fn upload_signing_keys_route(
/// ///
/// Uploads end-to-end key signatures from the sender user. /// Uploads end-to-end key signatures from the sender user.
pub(crate) async fn upload_signatures_route( pub(crate) async fn upload_signatures_route(
body: Ruma<upload_signatures::v3::Request>, State(services): State<crate::State>, body: Ruma<upload_signatures::v3::Request>,
) -> Result<upload_signatures::v3::Response> { ) -> Result<upload_signatures::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -173,7 +181,7 @@ pub(crate) async fn upload_signatures_route(
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
.to_owned(), .to_owned(),
); );
services() services
.users .users
.sign_key(user_id, key_id, signature, sender_user)?; .sign_key(user_id, key_id, signature, sender_user)?;
} }
@ -192,14 +200,14 @@ pub(crate) async fn upload_signatures_route(
/// ///
/// - TODO: left users /// - TODO: left users
pub(crate) async fn get_key_changes_route( pub(crate) async fn get_key_changes_route(
body: Ruma<get_key_changes::v3::Request>, State(services): State<crate::State>, body: Ruma<get_key_changes::v3::Request>,
) -> Result<get_key_changes::v3::Response> { ) -> Result<get_key_changes::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut device_list_updates = HashSet::new(); let mut device_list_updates = HashSet::new();
device_list_updates.extend( device_list_updates.extend(
services() services
.users .users
.keys_changed( .keys_changed(
sender_user.as_str(), sender_user.as_str(),
@ -215,14 +223,14 @@ pub(crate) async fn get_key_changes_route(
.filter_map(Result::ok), .filter_map(Result::ok),
); );
for room_id in services() for room_id in services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .filter_map(Result::ok)
{ {
device_list_updates.extend( device_list_updates.extend(
services() services
.users .users
.keys_changed( .keys_changed(
room_id.as_ref(), room_id.as_ref(),
@ -245,8 +253,8 @@ pub(crate) async fn get_key_changes_route(
} }
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>( pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F, services: &Services, sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
include_display_names: bool, allowed_signatures: F, include_display_names: bool,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response> {
let mut master_keys = BTreeMap::new(); let mut master_keys = BTreeMap::new();
let mut self_signing_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new();
@ -268,10 +276,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
if device_ids.is_empty() { if device_ids.is_empty() {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for device_id in services().users.all_device_ids(user_id) { for device_id in services.users.all_device_ids(user_id) {
let device_id = device_id?; let device_id = device_id?;
if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { if let Some(mut keys) = services.users.get_device_keys(user_id, &device_id)? {
let metadata = services() let metadata = services
.users .users
.get_device_metadata(user_id, &device_id)? .get_device_metadata(user_id, &device_id)?
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
@ -286,8 +294,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
} else { } else {
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { if let Some(mut keys) = services.users.get_device_keys(user_id, device_id)? {
let metadata = services() let metadata = services
.users .users
.get_device_metadata(user_id, device_id)? .get_device_metadata(user_id, device_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -303,21 +311,21 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
} }
} }
if let Some(master_key) = services() if let Some(master_key) = services
.users .users
.get_master_key(sender_user, user_id, &allowed_signatures)? .get_master_key(sender_user, user_id, &allowed_signatures)?
{ {
master_keys.insert(user_id.to_owned(), master_key); master_keys.insert(user_id.to_owned(), master_key);
} }
if let Some(self_signing_key) = if let Some(self_signing_key) =
services() services
.users .users
.get_self_signing_key(sender_user, user_id, &allowed_signatures)? .get_self_signing_key(sender_user, user_id, &allowed_signatures)?
{ {
self_signing_keys.insert(user_id.to_owned(), self_signing_key); self_signing_keys.insert(user_id.to_owned(), self_signing_key);
} }
if Some(user_id) == sender_user { if Some(user_id) == sender_user {
if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { if let Some(user_signing_key) = services.users.get_user_signing_key(user_id)? {
user_signing_keys.insert(user_id.to_owned(), user_signing_key); user_signing_keys.insert(user_id.to_owned(), user_signing_key);
} }
} }
@ -326,7 +334,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
let mut failures = BTreeMap::new(); let mut failures = BTreeMap::new();
let back_off = |id| async { let back_off = |id| async {
match services() match services
.globals .globals
.bad_query_ratelimiter .bad_query_ratelimiter
.write() .write()
@ -345,7 +353,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
let mut futures: FuturesUnordered<_> = get_over_federation let mut futures: FuturesUnordered<_> = get_over_federation
.into_iter() .into_iter()
.map(|(server, vec)| async move { .map(|(server, vec)| async move {
if let Some((time, tries)) = services() if let Some((time, tries)) = services
.globals .globals
.bad_query_ratelimiter .bad_query_ratelimiter
.read() .read()
@ -369,7 +377,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
let request = federation::keys::get_keys::v1::Request { let request = federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed, device_keys: device_keys_input_fed,
}; };
let response = services() let response = services
.sending .sending
.send_federation_request(server, request) .send_federation_request(server, request)
.await; .await;
@ -381,19 +389,19 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool + Send>(
while let Some((server, response)) = futures.next().await { while let Some((server, response)) = futures.next().await {
if let Ok(Ok(response)) = response { if let Ok(Ok(response)) = response {
for (user, masterkey) in response.master_keys { for (user, masterkey) in response.master_keys {
let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?; let (master_key_id, mut master_key) = services.users.parse_master_key(&user, &masterkey)?;
if let Some(our_master_key) = if let Some(our_master_key) =
services() services
.users .users
.get_key(&master_key_id, sender_user, &user, &allowed_signatures)? .get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
{ {
let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?; let (_, our_master_key) = services.users.parse_master_key(&user, &our_master_key)?;
master_key.signatures.extend(our_master_key.signatures); master_key.signatures.extend(our_master_key.signatures);
} }
let json = serde_json::to_value(master_key).expect("to_value always works"); let json = serde_json::to_value(master_key).expect("to_value always works");
let raw = serde_json::from_value(json).expect("Raw::from_value always works"); let raw = serde_json::from_value(json).expect("Raw::from_value always works");
services().users.add_cross_signing_keys( services.users.add_cross_signing_keys(
&user, &raw, &None, &None, &user, &raw, &None, &None,
false, /* Dont notify. A notification would trigger another key request resulting in an false, /* Dont notify. A notification would trigger another key request resulting in an
* endless loop */ * endless loop */
@ -444,7 +452,7 @@ fn add_unsigned_device_display_name(
} }
pub(crate) async fn claim_keys_helper( pub(crate) async fn claim_keys_helper(
one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>, services: &Services, one_time_keys_input: &BTreeMap<OwnedUserId, BTreeMap<OwnedDeviceId, DeviceKeyAlgorithm>>,
) -> Result<claim_keys::v3::Response> { ) -> Result<claim_keys::v3::Response> {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
@ -460,7 +468,7 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = services() if let Some(one_time_keys) = services
.users .users
.take_one_time_key(user_id, device_id, key_algorithm)? .take_one_time_key(user_id, device_id, key_algorithm)?
{ {
@ -483,7 +491,7 @@ pub(crate) async fn claim_keys_helper(
} }
( (
server, server,
services() services
.sending .sending
.send_federation_request( .send_federation_request(
server, server,

View file

@ -2,6 +2,7 @@
use std::{io::Cursor, sync::Arc, time::Duration}; use std::{io::Cursor, sync::Arc, time::Duration};
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{debug, error, utils::math::ruma_from_usize, warn}; use conduit::{debug, error, utils::math::ruma_from_usize, warn};
use image::io::Reader as ImgReader; use image::io::Reader as ImgReader;
@ -20,9 +21,8 @@ use crate::{
debug_warn, debug_warn,
service::{ service::{
media::{FileMeta, UrlPreviewData}, media::{FileMeta, UrlPreviewData},
server_is_ours, server_is_ours, Services,
}, },
services,
utils::{ utils::{
self, self,
content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename}, content_disposition::{content_disposition_type, make_content_disposition, sanitise_filename},
@ -42,10 +42,10 @@ const CORP_CROSS_ORIGIN: &str = "cross-origin";
/// ///
/// Returns max upload size. /// Returns max upload size.
pub(crate) async fn get_media_config_route( pub(crate) async fn get_media_config_route(
_body: Ruma<get_media_config::v3::Request>, State(services): State<crate::State>, _body: Ruma<get_media_config::v3::Request>,
) -> Result<get_media_config::v3::Response> { ) -> Result<get_media_config::v3::Response> {
Ok(get_media_config::v3::Response { Ok(get_media_config::v3::Response {
upload_size: ruma_from_usize(services().globals.config.max_request_size), upload_size: ruma_from_usize(services.globals.config.max_request_size),
}) })
} }
@ -57,9 +57,11 @@ pub(crate) async fn get_media_config_route(
/// ///
/// Returns max upload size. /// Returns max upload size.
pub(crate) async fn get_media_config_v1_route( pub(crate) async fn get_media_config_v1_route(
body: Ruma<get_media_config::v3::Request>, State(services): State<crate::State>, body: Ruma<get_media_config::v3::Request>,
) -> Result<RumaResponse<get_media_config::v3::Response>> { ) -> Result<RumaResponse<get_media_config::v3::Response>> {
get_media_config_route(body).await.map(RumaResponse) get_media_config_route(State(services), body)
.await
.map(RumaResponse)
} }
/// # `GET /_matrix/media/v3/preview_url` /// # `GET /_matrix/media/v3/preview_url`
@ -67,17 +69,18 @@ pub(crate) async fn get_media_config_v1_route(
/// Returns URL preview. /// Returns URL preview.
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] #[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
pub(crate) async fn get_media_preview_route( pub(crate) async fn get_media_preview_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_media_preview::v3::Request>,
) -> Result<get_media_preview::v3::Response> { ) -> Result<get_media_preview::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let url = &body.url; let url = &body.url;
if !url_preview_allowed(url) { if !url_preview_allowed(services, url) {
warn!(%sender_user, "URL is not allowed to be previewed: {url}"); warn!(%sender_user, "URL is not allowed to be previewed: {url}");
return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed")); return Err(Error::BadRequest(ErrorKind::forbidden(), "URL is not allowed to be previewed"));
} }
match get_url_preview(url).await { match get_url_preview(services, url).await {
Ok(preview) => { Ok(preview) => {
let res = serde_json::value::to_raw_value(&preview).map_err(|e| { let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}"); error!(%sender_user, "Failed to convert UrlPreviewData into a serde json value: {e}");
@ -115,9 +118,10 @@ pub(crate) async fn get_media_preview_route(
/// Returns URL preview. /// Returns URL preview.
#[tracing::instrument(skip_all, fields(%client), name = "url_preview")] #[tracing::instrument(skip_all, fields(%client), name = "url_preview")]
pub(crate) async fn get_media_preview_v1_route( pub(crate) async fn get_media_preview_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_media_preview::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_media_preview::v3::Request>,
) -> Result<RumaResponse<get_media_preview::v3::Response>> { ) -> Result<RumaResponse<get_media_preview::v3::Response>> {
get_media_preview_route(InsecureClientIp(client), body) get_media_preview_route(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
@ -130,17 +134,14 @@ pub(crate) async fn get_media_preview_v1_route(
/// - Media will be saved in the media/ directory /// - Media will be saved in the media/ directory
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")] #[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
pub(crate) async fn create_content_route( pub(crate) async fn create_content_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content::v3::Request>,
) -> Result<create_content::v3::Response> { ) -> Result<create_content::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mxc = format!( let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(MXC_LENGTH)
);
services() services
.media .media
.create( .create(
Some(sender_user.clone()), Some(sender_user.clone()),
@ -178,9 +179,10 @@ pub(crate) async fn create_content_route(
/// - Media will be saved in the media/ directory /// - Media will be saved in the media/ directory
#[tracing::instrument(skip_all, fields(%client), name = "media_upload")] #[tracing::instrument(skip_all, fields(%client), name = "media_upload")]
pub(crate) async fn create_content_v1_route( pub(crate) async fn create_content_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_content::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_content::v3::Request>,
) -> Result<RumaResponse<create_content::v3::Response>> { ) -> Result<RumaResponse<create_content::v3::Response>> {
create_content_route(InsecureClientIp(client), body) create_content_route(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
@ -195,7 +197,8 @@ pub(crate) async fn create_content_v1_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_route( pub(crate) async fn get_content_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content::v3::Request>,
) -> Result<get_content::v3::Response> { ) -> Result<get_content::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -203,7 +206,7 @@ pub(crate) async fn get_content_route(
content, content,
content_type, content_type,
content_disposition, content_disposition,
}) = services().media.get(&mxc).await? }) = services.media.get(&mxc).await?
{ {
let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None)); let content_disposition = Some(make_content_disposition(&content_type, content_disposition, None));
let file = content.expect("content"); let file = content.expect("content");
@ -217,6 +220,7 @@ pub(crate) async fn get_content_route(
}) })
} else if !server_is_ours(&body.server_name) && body.allow_remote { } else if !server_is_ours(&body.server_name) && body.allow_remote {
let response = get_remote_content( let response = get_remote_content(
services,
&mxc, &mxc,
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
@ -261,9 +265,10 @@ pub(crate) async fn get_content_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_v1_route( pub(crate) async fn get_content_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content::v3::Request>,
) -> Result<RumaResponse<get_content::v3::Response>> { ) -> Result<RumaResponse<get_content::v3::Response>> {
get_content_route(InsecureClientIp(client), body) get_content_route(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
@ -278,7 +283,8 @@ pub(crate) async fn get_content_v1_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_as_filename_route( pub(crate) async fn get_content_as_filename_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_as_filename::v3::Request>,
) -> Result<get_content_as_filename::v3::Response> { ) -> Result<get_content_as_filename::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -286,7 +292,7 @@ pub(crate) async fn get_content_as_filename_route(
content, content,
content_type, content_type,
content_disposition, content_disposition,
}) = services().media.get(&mxc).await? }) = services.media.get(&mxc).await?
{ {
let content_disposition = Some(make_content_disposition( let content_disposition = Some(make_content_disposition(
&content_type, &content_type,
@ -304,6 +310,7 @@ pub(crate) async fn get_content_as_filename_route(
}) })
} else if !server_is_ours(&body.server_name) && body.allow_remote { } else if !server_is_ours(&body.server_name) && body.allow_remote {
match get_remote_content( match get_remote_content(
services,
&mxc, &mxc,
&body.server_name, &body.server_name,
body.media_id.clone(), body.media_id.clone(),
@ -351,9 +358,10 @@ pub(crate) async fn get_content_as_filename_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_get")]
pub(crate) async fn get_content_as_filename_v1_route( pub(crate) async fn get_content_as_filename_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_as_filename::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_as_filename::v3::Request>,
) -> Result<RumaResponse<get_content_as_filename::v3::Response>> { ) -> Result<RumaResponse<get_content_as_filename::v3::Response>> {
get_content_as_filename_route(InsecureClientIp(client), body) get_content_as_filename_route(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
@ -368,7 +376,8 @@ pub(crate) async fn get_content_as_filename_v1_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
pub(crate) async fn get_content_thumbnail_route( pub(crate) async fn get_content_thumbnail_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_thumbnail::v3::Request>,
) -> Result<get_content_thumbnail::v3::Response> { ) -> Result<get_content_thumbnail::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -376,7 +385,7 @@ pub(crate) async fn get_content_thumbnail_route(
content, content,
content_type, content_type,
content_disposition, content_disposition,
}) = services() }) = services
.media .media
.get_thumbnail( .get_thumbnail(
&mxc, &mxc,
@ -400,7 +409,7 @@ pub(crate) async fn get_content_thumbnail_route(
content_disposition, content_disposition,
}) })
} else if !server_is_ours(&body.server_name) && body.allow_remote { } else if !server_is_ours(&body.server_name) && body.allow_remote {
if services() if services
.globals .globals
.prevent_media_downloads_from() .prevent_media_downloads_from()
.contains(&body.server_name) .contains(&body.server_name)
@ -411,7 +420,7 @@ pub(crate) async fn get_content_thumbnail_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
} }
match services() match services
.sending .sending
.send_federation_request( .send_federation_request(
&body.server_name, &body.server_name,
@ -430,7 +439,7 @@ pub(crate) async fn get_content_thumbnail_route(
.await .await
{ {
Ok(get_thumbnail_response) => { Ok(get_thumbnail_response) => {
services() services
.media .media
.upload_thumbnail( .upload_thumbnail(
None, None,
@ -481,17 +490,19 @@ pub(crate) async fn get_content_thumbnail_route(
/// seconds /// seconds
#[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")] #[tracing::instrument(skip_all, fields(%client), name = "media_thumbnail_get")]
pub(crate) async fn get_content_thumbnail_v1_route( pub(crate) async fn get_content_thumbnail_v1_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_content_thumbnail::v3::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_content_thumbnail::v3::Request>,
) -> Result<RumaResponse<get_content_thumbnail::v3::Response>> { ) -> Result<RumaResponse<get_content_thumbnail::v3::Response>> {
get_content_thumbnail_route(InsecureClientIp(client), body) get_content_thumbnail_route(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
async fn get_remote_content( async fn get_remote_content(
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration, services: &Services, mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool,
timeout_ms: Duration,
) -> Result<get_content::v3::Response, Error> { ) -> Result<get_content::v3::Response, Error> {
if services() if services
.globals .globals
.prevent_media_downloads_from() .prevent_media_downloads_from()
.contains(&server_name.to_owned()) .contains(&server_name.to_owned())
@ -502,7 +513,7 @@ async fn get_remote_content(
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
} }
let content_response = services() let content_response = services
.sending .sending
.send_federation_request( .send_federation_request(
server_name, server_name,
@ -522,7 +533,7 @@ async fn get_remote_content(
None, None,
)); ));
services() services
.media .media
.create( .create(
None, None,
@ -542,15 +553,11 @@ async fn get_remote_content(
}) })
} }
async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { async fn download_image(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let image = client.get(url).send().await?.bytes().await?; let image = client.get(url).send().await?.bytes().await?;
let mxc = format!( let mxc = format!("mxc://{}/{}", services.globals.server_name(), utils::random_string(MXC_LENGTH));
"mxc://{}/{}",
services().globals.server_name(),
utils::random_string(MXC_LENGTH)
);
services() services
.media .media
.create(None, &mxc, None, None, &image) .create(None, &mxc, None, None, &image)
.await?; .await?;
@ -572,18 +579,18 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
}) })
} }
async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> { async fn download_html(services: &Services, client: &reqwest::Client, url: &str) -> Result<UrlPreviewData> {
let mut response = client.get(url).send().await?; let mut response = client.get(url).send().await?;
let mut bytes: Vec<u8> = Vec::new(); let mut bytes: Vec<u8> = Vec::new();
while let Some(chunk) = response.chunk().await? { while let Some(chunk) = response.chunk().await? {
bytes.extend_from_slice(&chunk); bytes.extend_from_slice(&chunk);
if bytes.len() > services().globals.url_preview_max_spider_size() { if bytes.len() > services.globals.url_preview_max_spider_size() {
debug!( debug!(
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \ "Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
response body and assuming our necessary data is in this range.", response body and assuming our necessary data is in this range.",
url, url,
services().globals.url_preview_max_spider_size() services.globals.url_preview_max_spider_size()
); );
break; break;
} }
@ -595,7 +602,7 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
let mut data = match html.opengraph.images.first() { let mut data = match html.opengraph.images.first() {
None => UrlPreviewData::default(), None => UrlPreviewData::default(),
Some(obj) => download_image(client, &obj.url).await?, Some(obj) => download_image(services, client, &obj.url).await?,
}; };
let props = html.opengraph.properties; let props = html.opengraph.properties;
@ -607,19 +614,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
Ok(data) Ok(data)
} }
async fn request_url_preview(url: &str) -> Result<UrlPreviewData> { async fn request_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
if let Ok(ip) = IPAddress::parse(url) { if let Ok(ip) = IPAddress::parse(url) {
if !services().globals.valid_cidr_range(&ip) { if !services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Requesting from this address is forbidden")); return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
} }
} }
let client = &services().globals.client.url_preview; let client = &services.globals.client.url_preview;
let response = client.head(url).send().await?; let response = client.head(url).send().await?;
if let Some(remote_addr) = response.remote_addr() { if let Some(remote_addr) = response.remote_addr() {
if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) { if let Ok(ip) = IPAddress::parse(remote_addr.ip().to_string()) {
if !services().globals.valid_cidr_range(&ip) { if !services.globals.valid_cidr_range(&ip) {
return Err(Error::BadServerResponse("Requesting from this address is forbidden")); return Err(Error::BadServerResponse("Requesting from this address is forbidden"));
} }
} }
@ -633,24 +640,24 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")); return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type"));
}; };
let data = match content_type { let data = match content_type {
html if html.starts_with("text/html") => download_html(client, url).await?, html if html.starts_with("text/html") => download_html(services, client, url).await?,
img if img.starts_with("image/") => download_image(client, url).await?, img if img.starts_with("image/") => download_image(services, client, url).await?,
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")), _ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
}; };
services().media.set_url_preview(url, &data).await?; services.media.set_url_preview(url, &data).await?;
Ok(data) Ok(data)
} }
async fn get_url_preview(url: &str) -> Result<UrlPreviewData> { async fn get_url_preview(services: &Services, url: &str) -> Result<UrlPreviewData> {
if let Some(preview) = services().media.get_url_preview(url).await { if let Some(preview) = services.media.get_url_preview(url).await {
return Ok(preview); return Ok(preview);
} }
// ensure that only one request is made per URL // ensure that only one request is made per URL
let mutex_request = Arc::clone( let mutex_request = Arc::clone(
services() services
.media .media
.url_preview_mutex .url_preview_mutex
.write() .write()
@ -660,13 +667,13 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
); );
let _request_lock = mutex_request.lock().await; let _request_lock = mutex_request.lock().await;
match services().media.get_url_preview(url).await { match services.media.get_url_preview(url).await {
Some(preview) => Ok(preview), Some(preview) => Ok(preview),
None => request_url_preview(url).await, None => request_url_preview(services, url).await,
} }
} }
fn url_preview_allowed(url_str: &str) -> bool { fn url_preview_allowed(services: &Services, url_str: &str) -> bool {
let url: Url = match Url::parse(url_str) { let url: Url = match Url::parse(url_str) {
Ok(u) => u, Ok(u) => u,
Err(e) => { Err(e) => {
@ -691,10 +698,10 @@ fn url_preview_allowed(url_str: &str) -> bool {
Some(h) => h.to_owned(), Some(h) => h.to_owned(),
}; };
let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist(); let allowlist_domain_contains = services.globals.url_preview_domain_contains_allowlist();
let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist(); let allowlist_domain_explicit = services.globals.url_preview_domain_explicit_allowlist();
let denylist_domain_explicit = services().globals.url_preview_domain_explicit_denylist(); let denylist_domain_explicit = services.globals.url_preview_domain_explicit_denylist();
let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist(); let allowlist_url_contains = services.globals.url_preview_url_contains_allowlist();
if allowlist_domain_contains.contains(&"*".to_owned()) if allowlist_domain_contains.contains(&"*".to_owned())
|| allowlist_domain_explicit.contains(&"*".to_owned()) || allowlist_domain_explicit.contains(&"*".to_owned())
@ -735,7 +742,7 @@ fn url_preview_allowed(url_str: &str) -> bool {
} }
// check root domain if available and if user has root domain checks // check root domain if available and if user has root domain checks
if services().globals.url_preview_check_root_domain() { if services.globals.url_preview_check_root_domain() {
debug!("Checking root domain"); debug!("Checking root domain");
match host.split_once('.') { match host.split_once('.') {
None => return false, None => return false,

File diff suppressed because it is too large Load diff

View file

@ -1,5 +1,6 @@
use std::collections::{BTreeMap, HashSet}; use std::collections::{BTreeMap, HashSet};
use axum::extract::State;
use conduit::PduCount; use conduit::PduCount;
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -12,7 +13,10 @@ use ruma::{
}; };
use serde_json::{from_str, Value}; use serde_json::{from_str, Value};
use crate::{service::pdu::PduBuilder, services, utils, Error, PduEvent, Result, Ruma}; use crate::{
service::{pdu::PduBuilder, Services},
utils, Error, PduEvent, Result, Ruma,
};
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
/// ///
@ -24,21 +28,19 @@ use crate::{service::pdu::PduBuilder, services, utils, Error, PduEvent, Result,
/// - Tries to send the event into the room, auth rules will determine if it is /// - Tries to send the event into the room, auth rules will determine if it is
/// allowed /// allowed
pub(crate) async fn send_message_event_route( pub(crate) async fn send_message_event_route(
body: Ruma<send_message_event::v3::Request>, State(services): State<crate::State>, body: Ruma<send_message_event::v3::Request>,
) -> Result<send_message_event::v3::Response> { ) -> Result<send_message_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
// Forbid m.room.encrypted if encryption is disabled // Forbid m.room.encrypted if encryption is disabled
if MessageLikeEventType::RoomEncrypted == body.event_type && !services().globals.allow_encryption() { if MessageLikeEventType::RoomEncrypted == body.event_type && !services.globals.allow_encryption() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled"));
} }
if body.event_type == MessageLikeEventType::CallInvite if body.event_type == MessageLikeEventType::CallInvite && services.rooms.directory.is_public_room(&body.room_id)? {
&& services().rooms.directory.is_public_room(&body.room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"Room call invites are not allowed in public rooms", "Room call invites are not allowed in public rooms",
@ -46,7 +48,7 @@ pub(crate) async fn send_message_event_route(
} }
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = services() if let Some(response) = services
.transaction_ids .transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)?
{ {
@ -71,7 +73,7 @@ pub(crate) async fn send_message_event_route(
let mut unsigned = BTreeMap::new(); let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let event_id = services() let event_id = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -89,7 +91,7 @@ pub(crate) async fn send_message_event_route(
) )
.await?; .await?;
services() services
.transaction_ids .transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; .add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
@ -105,7 +107,7 @@ pub(crate) async fn send_message_event_route(
/// - Only works if the user is joined (TODO: always allow, but only show events /// - Only works if the user is joined (TODO: always allow, but only show events
/// where the user was joined, depending on `history_visibility`) /// where the user was joined, depending on `history_visibility`)
pub(crate) async fn get_message_events_route( pub(crate) async fn get_message_events_route(
body: Ruma<get_message_events::v3::Request>, State(services): State<crate::State>, body: Ruma<get_message_events::v3::Request>,
) -> Result<get_message_events::v3::Response> { ) -> Result<get_message_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
@ -123,7 +125,7 @@ pub(crate) async fn get_message_events_route(
.as_ref() .as_ref()
.and_then(|t| PduCount::try_from_string(t).ok()); .and_then(|t| PduCount::try_from_string(t).ok());
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from) .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)
@ -139,12 +141,12 @@ pub(crate) async fn get_message_events_route(
match body.dir { match body.dir {
ruma::api::Direction::Forward => { ruma::api::Direction::Forward => {
let events_after: Vec<_> = services() let events_after: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_after(sender_user, &body.room_id, from)? .pdus_after(sender_user, &body.room_id, from)?
.filter_map(Result::ok) // Filter out buggy events .filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(pdu, sender_user, &body.room_id) .filter(|(_, pdu)| { contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
@ -157,7 +159,7 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
*/ */
if !cfg!(feature = "element_hacks") if !cfg!(feature = "element_hacks")
&& !services().rooms.lazy_loading.lazy_load_was_sent_before( && !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
@ -181,17 +183,17 @@ pub(crate) async fn get_message_events_route(
resp.chunk = events_after; resp.chunk = events_after;
}, },
ruma::api::Direction::Backward => { ruma::api::Direction::Backward => {
services() services
.rooms .rooms
.timeline .timeline
.backfill_if_required(&body.room_id, from) .backfill_if_required(&body.room_id, from)
.await?; .await?;
let events_before: Vec<_> = services() let events_before: Vec<_> = services
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, &body.room_id, from)? .pdus_until(sender_user, &body.room_id, from)?
.filter_map(Result::ok) // Filter out buggy events .filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(pdu, sender_user, &body.room_id)}) .filter(|(_, pdu)| {contains_url_filter(pdu, &body.filter) && visibility_filter(services, pdu, sender_user, &body.room_id)})
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|&(k, _)| Some(k) != to) // Stop at `to`
.take(limit) .take(limit)
.collect(); .collect();
@ -202,7 +204,7 @@ pub(crate) async fn get_message_events_route(
* https://github.com/vector-im/element-web/issues/21034 * https://github.com/vector-im/element-web/issues/21034
*/ */
if !cfg!(feature = "element_hacks") if !cfg!(feature = "element_hacks")
&& !services().rooms.lazy_loading.lazy_load_was_sent_before( && !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
@ -229,11 +231,12 @@ pub(crate) async fn get_message_events_route(
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = services().rooms.state_accessor.room_state_get( if let Some(member_event) =
&body.room_id, services
&StateEventType::RoomMember, .rooms
ll_id.as_str(), .state_accessor
)? { .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
{
resp.state.push(member_event.to_state_event()); resp.state.push(member_event.to_state_event());
} }
} }
@ -241,7 +244,7 @@ pub(crate) async fn get_message_events_route(
// remove the feature check when we are sure clients like element can handle it // remove the feature check when we are sure clients like element can handle it
if !cfg!(feature = "element_hacks") { if !cfg!(feature = "element_hacks") {
if let Some(next_token) = next_token { if let Some(next_token) = next_token {
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token) .lazy_load_mark_sent(sender_user, sender_device, &body.room_id, lazy_loaded, next_token)
@ -252,8 +255,8 @@ pub(crate) async fn get_message_events_route(
Ok(resp) Ok(resp)
} }
fn visibility_filter(pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool { fn visibility_filter(services: &Services, pdu: &PduEvent, user_id: &UserId, room_id: &RoomId) -> bool {
services() services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(user_id, room_id, &pdu.event_id) .user_can_see_event(user_id, room_id, &pdu.event_id)

View file

@ -1,5 +1,6 @@
use std::time::Duration; use std::time::Duration;
use axum::extract::State;
use conduit::utils; use conduit::utils;
use ruma::{ use ruma::{
api::client::{account, error::ErrorKind}, api::client::{account, error::ErrorKind},
@ -7,7 +8,7 @@ use ruma::{
}; };
use super::TOKEN_LENGTH; use super::TOKEN_LENGTH;
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `POST /_matrix/client/v3/user/{userId}/openid/request_token` /// # `POST /_matrix/client/v3/user/{userId}/openid/request_token`
/// ///
@ -15,7 +16,7 @@ use crate::{services, Error, Result, Ruma};
/// ///
/// - The token generated is only valid for the OpenID API /// - The token generated is only valid for the OpenID API
pub(crate) async fn create_openid_token_route( pub(crate) async fn create_openid_token_route(
body: Ruma<account::request_openid_token::v3::Request>, State(services): State<crate::State>, body: Ruma<account::request_openid_token::v3::Request>,
) -> Result<account::request_openid_token::v3::Response> { ) -> Result<account::request_openid_token::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -28,14 +29,14 @@ pub(crate) async fn create_openid_token_route(
let access_token = utils::random_string(TOKEN_LENGTH); let access_token = utils::random_string(TOKEN_LENGTH);
let expires_in = services() let expires_in = services
.users .users
.create_openid_token(&body.user_id, &access_token)?; .create_openid_token(&body.user_id, &access_token)?;
Ok(account::request_openid_token::v3::Response { Ok(account::request_openid_token::v3::Response {
access_token, access_token,
token_type: TokenType::Bearer, token_type: TokenType::Bearer,
matrix_server_name: services().globals.config.server_name.clone(), matrix_server_name: services.globals.config.server_name.clone(),
expires_in: Duration::from_secs(expires_in), expires_in: Duration::from_secs(expires_in),
}) })
} }

View file

@ -1,22 +1,24 @@
use std::time::Duration; use std::time::Duration;
use axum::extract::State;
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
presence::{get_presence, set_presence}, presence::{get_presence, set_presence},
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/presence/{userId}/status` /// # `PUT /_matrix/client/r0/presence/{userId}/status`
/// ///
/// Sets the presence state of the sender user. /// Sets the presence state of the sender user.
pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> { pub(crate) async fn set_presence_route(
if !services().globals.allow_local_presence() { State(services): State<crate::State>, body: Ruma<set_presence::v3::Request>,
) -> Result<set_presence::v3::Response> {
if !services.globals.allow_local_presence() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server"));
} }
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if sender_user != &body.user_id && body.appservice_info.is_none() { if sender_user != &body.user_id && body.appservice_info.is_none() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -24,7 +26,7 @@ pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) ->
)); ));
} }
services() services
.presence .presence
.set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?; .set_presence(sender_user, &body.presence, None, None, body.status_msg.clone())?;
@ -36,8 +38,10 @@ pub(crate) async fn set_presence_route(body: Ruma<set_presence::v3::Request>) ->
/// Gets the presence state of the given user. /// Gets the presence state of the given user.
/// ///
/// - Only works if you share a room with the user /// - Only works if you share a room with the user
pub(crate) async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> { pub(crate) async fn get_presence_route(
if !services().globals.allow_local_presence() { State(services): State<crate::State>, body: Ruma<get_presence::v3::Request>,
) -> Result<get_presence::v3::Response> {
if !services.globals.allow_local_presence() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Presence is disabled on this server"));
} }
@ -45,12 +49,12 @@ pub(crate) async fn get_presence_route(body: Ruma<get_presence::v3::Request>) ->
let mut presence_event = None; let mut presence_event = None;
for _room_id in services() for _room_id in services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{ {
if let Some(presence) = services().presence.get_presence(&body.user_id)? { if let Some(presence) = services.presence.get_presence(&body.user_id)? {
presence_event = Some(presence); presence_event = Some(presence);
break; break;
} }

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -14,8 +15,8 @@ use serde_json::value::to_raw_value;
use tracing::warn; use tracing::warn;
use crate::{ use crate::{
service::{pdu::PduBuilder, user_is_local}, service::{pdu::PduBuilder, user_is_local, Services},
services, Error, Result, Ruma, Error, Result, Ruma,
}; };
/// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// # `PUT /_matrix/client/r0/profile/{userId}/displayname`
@ -24,21 +25,21 @@ use crate::{
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub(crate) async fn set_displayname_route( pub(crate) async fn set_displayname_route(
body: Ruma<set_display_name::v3::Request>, State(services): State<crate::State>, body: Ruma<set_display_name::v3::Request>,
) -> Result<set_display_name::v3::Response> { ) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let all_joined_rooms: Vec<OwnedRoomId> = services() let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
update_displayname(sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?; update_displayname(services, sender_user.clone(), body.displayname.clone(), all_joined_rooms).await?;
if services().globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services() services
.presence .presence
.ping_presence(sender_user, &PresenceState::Online)?; .ping_presence(sender_user, &PresenceState::Online)?;
} }
@ -53,11 +54,11 @@ pub(crate) async fn set_displayname_route(
/// - If user is on another server and we do not have a local copy already fetch /// - If user is on another server and we do not have a local copy already fetch
/// displayname over federation /// displayname over federation
pub(crate) async fn get_displayname_route( pub(crate) async fn get_displayname_route(
body: Ruma<get_display_name::v3::Request>, State(services): State<crate::State>, body: Ruma<get_display_name::v3::Request>,
) -> Result<get_display_name::v3::Response> { ) -> Result<get_display_name::v3::Response> {
if !user_is_local(&body.user_id) { if !user_is_local(&body.user_id) {
// Create and update our local copy of the user // Create and update our local copy of the user
if let Ok(response) = services() if let Ok(response) = services
.sending .sending
.send_federation_request( .send_federation_request(
body.user_id.server_name(), body.user_id.server_name(),
@ -68,19 +69,19 @@ pub(crate) async fn get_displayname_route(
) )
.await .await
{ {
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
services().users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services() services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone())
.await?; .await?;
services() services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?; .await?;
services() services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone())
.await?; .await?;
@ -91,14 +92,14 @@ pub(crate) async fn get_displayname_route(
} }
} }
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_display_name::v3::Response { Ok(get_display_name::v3::Response {
displayname: services().users.displayname(&body.user_id)?, displayname: services.users.displayname(&body.user_id)?,
}) })
} }
@ -108,10 +109,10 @@ pub(crate) async fn get_displayname_route(
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub(crate) async fn set_avatar_url_route( pub(crate) async fn set_avatar_url_route(
body: Ruma<set_avatar_url::v3::Request>, State(services): State<crate::State>, body: Ruma<set_avatar_url::v3::Request>,
) -> Result<set_avatar_url::v3::Response> { ) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let all_joined_rooms: Vec<OwnedRoomId> = services() let all_joined_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
@ -119,6 +120,7 @@ pub(crate) async fn set_avatar_url_route(
.collect(); .collect();
update_avatar_url( update_avatar_url(
services,
sender_user.clone(), sender_user.clone(),
body.avatar_url.clone(), body.avatar_url.clone(),
body.blurhash.clone(), body.blurhash.clone(),
@ -126,9 +128,9 @@ pub(crate) async fn set_avatar_url_route(
) )
.await?; .await?;
if services().globals.allow_local_presence() { if services.globals.allow_local_presence() {
// Presence update // Presence update
services() services
.presence .presence
.ping_presence(sender_user, &PresenceState::Online)?; .ping_presence(sender_user, &PresenceState::Online)?;
} }
@ -143,11 +145,11 @@ pub(crate) async fn set_avatar_url_route(
/// - If user is on another server and we do not have a local copy already fetch /// - If user is on another server and we do not have a local copy already fetch
/// `avatar_url` and blurhash over federation /// `avatar_url` and blurhash over federation
pub(crate) async fn get_avatar_url_route( pub(crate) async fn get_avatar_url_route(
body: Ruma<get_avatar_url::v3::Request>, State(services): State<crate::State>, body: Ruma<get_avatar_url::v3::Request>,
) -> Result<get_avatar_url::v3::Response> { ) -> Result<get_avatar_url::v3::Response> {
if !user_is_local(&body.user_id) { if !user_is_local(&body.user_id) {
// Create and update our local copy of the user // Create and update our local copy of the user
if let Ok(response) = services() if let Ok(response) = services
.sending .sending
.send_federation_request( .send_federation_request(
body.user_id.server_name(), body.user_id.server_name(),
@ -158,19 +160,19 @@ pub(crate) async fn get_avatar_url_route(
) )
.await .await
{ {
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
services().users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services() services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone())
.await?; .await?;
services() services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?; .await?;
services() services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone())
.await?; .await?;
@ -182,15 +184,15 @@ pub(crate) async fn get_avatar_url_route(
} }
} }
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_avatar_url::v3::Response { Ok(get_avatar_url::v3::Response {
avatar_url: services().users.avatar_url(&body.user_id)?, avatar_url: services.users.avatar_url(&body.user_id)?,
blurhash: services().users.blurhash(&body.user_id)?, blurhash: services.users.blurhash(&body.user_id)?,
}) })
} }
@ -200,10 +202,12 @@ pub(crate) async fn get_avatar_url_route(
/// ///
/// - If user is on another server and we do not have a local copy already, /// - If user is on another server and we do not have a local copy already,
/// fetch profile over federation. /// fetch profile over federation.
pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> { pub(crate) async fn get_profile_route(
State(services): State<crate::State>, body: Ruma<get_profile::v3::Request>,
) -> Result<get_profile::v3::Response> {
if !user_is_local(&body.user_id) { if !user_is_local(&body.user_id) {
// Create and update our local copy of the user // Create and update our local copy of the user
if let Ok(response) = services() if let Ok(response) = services
.sending .sending
.send_federation_request( .send_federation_request(
body.user_id.server_name(), body.user_id.server_name(),
@ -214,19 +218,19 @@ pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> R
) )
.await .await
{ {
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
services().users.create(&body.user_id, None)?; services.users.create(&body.user_id, None)?;
} }
services() services
.users .users
.set_displayname(&body.user_id, response.displayname.clone()) .set_displayname(&body.user_id, response.displayname.clone())
.await?; .await?;
services() services
.users .users
.set_avatar_url(&body.user_id, response.avatar_url.clone()) .set_avatar_url(&body.user_id, response.avatar_url.clone())
.await?; .await?;
services() services
.users .users
.set_blurhash(&body.user_id, response.blurhash.clone()) .set_blurhash(&body.user_id, response.blurhash.clone())
.await?; .await?;
@ -239,23 +243,23 @@ pub(crate) async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> R
} }
} }
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist and we couldn't fetch it over // Return 404 if this user doesn't exist and we couldn't fetch it over
// federation // federation
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
} }
Ok(get_profile::v3::Response { Ok(get_profile::v3::Response {
avatar_url: services().users.avatar_url(&body.user_id)?, avatar_url: services.users.avatar_url(&body.user_id)?,
blurhash: services().users.blurhash(&body.user_id)?, blurhash: services.users.blurhash(&body.user_id)?,
displayname: services().users.displayname(&body.user_id)?, displayname: services.users.displayname(&body.user_id)?,
}) })
} }
pub async fn update_displayname( pub async fn update_displayname(
user_id: OwnedUserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, services: &Services, user_id: OwnedUserId, displayname: Option<String>, all_joined_rooms: Vec<OwnedRoomId>,
) -> Result<()> { ) -> Result<()> {
services() services
.users .users
.set_displayname(&user_id, displayname.clone()) .set_displayname(&user_id, displayname.clone())
.await?; .await?;
@ -271,7 +275,7 @@ pub async fn update_displayname(
displayname: displayname.clone(), displayname: displayname.clone(),
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
..serde_json::from_str( ..serde_json::from_str(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
@ -294,19 +298,20 @@ pub async fn update_displayname(
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
update_all_rooms(all_joined_rooms, user_id).await; update_all_rooms(services, all_joined_rooms, user_id).await;
Ok(()) Ok(())
} }
pub async fn update_avatar_url( pub async fn update_avatar_url(
user_id: OwnedUserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>, all_joined_rooms: Vec<OwnedRoomId>, services: &Services, user_id: OwnedUserId, avatar_url: Option<OwnedMxcUri>, blurhash: Option<String>,
all_joined_rooms: Vec<OwnedRoomId>,
) -> Result<()> { ) -> Result<()> {
services() services
.users .users
.set_avatar_url(&user_id, avatar_url.clone()) .set_avatar_url(&user_id, avatar_url.clone())
.await?; .await?;
services() services
.users .users
.set_blurhash(&user_id, blurhash.clone()) .set_blurhash(&user_id, blurhash.clone())
.await?; .await?;
@ -323,7 +328,7 @@ pub async fn update_avatar_url(
blurhash: blurhash.clone(), blurhash: blurhash.clone(),
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
..serde_json::from_str( ..serde_json::from_str(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
@ -346,15 +351,17 @@ pub async fn update_avatar_url(
.filter_map(Result::ok) .filter_map(Result::ok)
.collect(); .collect();
update_all_rooms(all_joined_rooms, user_id).await; update_all_rooms(services, all_joined_rooms, user_id).await;
Ok(()) Ok(())
} }
pub async fn update_all_rooms(all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId) { pub async fn update_all_rooms(
services: &Services, all_joined_rooms: Vec<(PduBuilder, &OwnedRoomId)>, user_id: OwnedUserId,
) {
for (pdu_builder, room_id) in all_joined_rooms { for (pdu_builder, room_id) in all_joined_rooms {
let state_lock = services().rooms.state.mutex.lock(room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
if let Err(e) = services() if let Err(e) = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, &user_id, room_id, &state_lock) .build_and_append_pdu(pdu_builder, &user_id, room_id, &state_lock)

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -10,18 +11,18 @@ use ruma::{
push::{InsertPushRuleError, RemovePushRuleError, Ruleset}, push::{InsertPushRuleError, RemovePushRuleError, Ruleset},
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/client/r0/pushrules/` /// # `GET /_matrix/client/r0/pushrules/`
/// ///
/// Retrieves the push rules event for this user. /// Retrieves the push rules event for this user.
pub(crate) async fn get_pushrules_all_route( pub(crate) async fn get_pushrules_all_route(
body: Ruma<get_pushrules_all::v3::Request>, State(services): State<crate::State>, body: Ruma<get_pushrules_all::v3::Request>,
) -> Result<get_pushrules_all::v3::Response> { ) -> Result<get_pushrules_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = let event =
services() services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?; .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?;
@ -34,7 +35,7 @@ pub(crate) async fn get_pushrules_all_route(
global: account_data.global, global: account_data.global,
}) })
} else { } else {
services().account_data.update( services.account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -55,10 +56,12 @@ pub(crate) async fn get_pushrules_all_route(
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
/// ///
/// Retrieves a single specified push rule for this user. /// Retrieves a single specified push rule for this user.
pub(crate) async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> { pub(crate) async fn get_pushrule_route(
State(services): State<crate::State>, body: Ruma<get_pushrule::v3::Request>,
) -> Result<get_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -84,7 +87,9 @@ pub(crate) async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) ->
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
/// ///
/// Creates a single specified push rule for this user. /// Creates a single specified push rule for this user.
pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> { pub(crate) async fn set_pushrule_route(
State(services): State<crate::State>, body: Ruma<set_pushrule::v3::Request>,
) -> Result<set_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body; let body = body.body;
@ -95,7 +100,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) ->
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -134,7 +139,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) ->
return Err(err); return Err(err);
} }
services().account_data.update( services.account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -148,7 +153,7 @@ pub(crate) async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) ->
/// ///
/// Gets the actions of a single specified push rule for this user. /// Gets the actions of a single specified push rule for this user.
pub(crate) async fn get_pushrule_actions_route( pub(crate) async fn get_pushrule_actions_route(
body: Ruma<get_pushrule_actions::v3::Request>, State(services): State<crate::State>, body: Ruma<get_pushrule_actions::v3::Request>,
) -> Result<get_pushrule_actions::v3::Response> { ) -> Result<get_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -159,7 +164,7 @@ pub(crate) async fn get_pushrule_actions_route(
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -183,7 +188,7 @@ pub(crate) async fn get_pushrule_actions_route(
/// ///
/// Sets the actions of a single specified push rule for this user. /// Sets the actions of a single specified push rule for this user.
pub(crate) async fn set_pushrule_actions_route( pub(crate) async fn set_pushrule_actions_route(
body: Ruma<set_pushrule_actions::v3::Request>, State(services): State<crate::State>, body: Ruma<set_pushrule_actions::v3::Request>,
) -> Result<set_pushrule_actions::v3::Response> { ) -> Result<set_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -194,7 +199,7 @@ pub(crate) async fn set_pushrule_actions_route(
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -211,7 +216,7 @@ pub(crate) async fn set_pushrule_actions_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
services().account_data.update( services.account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -225,7 +230,7 @@ pub(crate) async fn set_pushrule_actions_route(
/// ///
/// Gets the enabled status of a single specified push rule for this user. /// Gets the enabled status of a single specified push rule for this user.
pub(crate) async fn get_pushrule_enabled_route( pub(crate) async fn get_pushrule_enabled_route(
body: Ruma<get_pushrule_enabled::v3::Request>, State(services): State<crate::State>, body: Ruma<get_pushrule_enabled::v3::Request>,
) -> Result<get_pushrule_enabled::v3::Response> { ) -> Result<get_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -236,7 +241,7 @@ pub(crate) async fn get_pushrule_enabled_route(
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -259,7 +264,7 @@ pub(crate) async fn get_pushrule_enabled_route(
/// ///
/// Sets the enabled status of a single specified push rule for this user. /// Sets the enabled status of a single specified push rule for this user.
pub(crate) async fn set_pushrule_enabled_route( pub(crate) async fn set_pushrule_enabled_route(
body: Ruma<set_pushrule_enabled::v3::Request>, State(services): State<crate::State>, body: Ruma<set_pushrule_enabled::v3::Request>,
) -> Result<set_pushrule_enabled::v3::Response> { ) -> Result<set_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -270,7 +275,7 @@ pub(crate) async fn set_pushrule_enabled_route(
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -287,7 +292,7 @@ pub(crate) async fn set_pushrule_enabled_route(
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
} }
services().account_data.update( services.account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -301,7 +306,7 @@ pub(crate) async fn set_pushrule_enabled_route(
/// ///
/// Deletes a single specified push rule for this user. /// Deletes a single specified push rule for this user.
pub(crate) async fn delete_pushrule_route( pub(crate) async fn delete_pushrule_route(
body: Ruma<delete_pushrule::v3::Request>, State(services): State<crate::State>, body: Ruma<delete_pushrule::v3::Request>,
) -> Result<delete_pushrule::v3::Response> { ) -> Result<delete_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -312,7 +317,7 @@ pub(crate) async fn delete_pushrule_route(
)); ));
} }
let event = services() let event = services
.account_data .account_data
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
@ -336,7 +341,7 @@ pub(crate) async fn delete_pushrule_route(
return Err(err); return Err(err);
} }
services().account_data.update( services.account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -349,11 +354,13 @@ pub(crate) async fn delete_pushrule_route(
/// # `GET /_matrix/client/r0/pushers` /// # `GET /_matrix/client/r0/pushers`
/// ///
/// Gets all currently active pushers for the sender user. /// Gets all currently active pushers for the sender user.
pub(crate) async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> { pub(crate) async fn get_pushers_route(
State(services): State<crate::State>, body: Ruma<get_pushers::v3::Request>,
) -> Result<get_pushers::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_pushers::v3::Response { Ok(get_pushers::v3::Response {
pushers: services().pusher.get_pushers(sender_user)?, pushers: services.pusher.get_pushers(sender_user)?,
}) })
} }
@ -362,10 +369,12 @@ pub(crate) async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> R
/// Adds a pusher for the sender user. /// Adds a pusher for the sender user.
/// ///
/// - TODO: Handle `append` /// - TODO: Handle `append`
pub(crate) async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> { pub(crate) async fn set_pushers_route(
State(services): State<crate::State>, body: Ruma<set_pusher::v3::Request>,
) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services().pusher.set_pusher(sender_user, &body.action)?; services.pusher.set_pusher(sender_user, &body.action)?;
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use conduit::PduCount; use conduit::PduCount;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
@ -10,7 +11,7 @@ use ruma::{
MilliSecondsSinceUnixEpoch, MilliSecondsSinceUnixEpoch,
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
/// ///
@ -20,7 +21,7 @@ use crate::{services, Error, Result, Ruma};
/// - If `read_receipt` is set: Update private marker and public read receipt /// - If `read_receipt` is set: Update private marker and public read receipt
/// EDU /// EDU
pub(crate) async fn set_read_marker_route( pub(crate) async fn set_read_marker_route(
body: Ruma<set_read_marker::v3::Request>, State(services): State<crate::State>, body: Ruma<set_read_marker::v3::Request>,
) -> Result<set_read_marker::v3::Response> { ) -> Result<set_read_marker::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -30,7 +31,7 @@ pub(crate) async fn set_read_marker_route(
event_id: fully_read.clone(), event_id: fully_read.clone(),
}, },
}; };
services().account_data.update( services.account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
@ -39,14 +40,14 @@ pub(crate) async fn set_read_marker_route(
} }
if body.private_read_receipt.is_some() || body.read_receipt.is_some() { if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
services() services
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)?;
} }
if let Some(event) = &body.private_read_receipt { if let Some(event) = &body.private_read_receipt {
let count = services() let count = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(event)? .get_pdu_count(event)?
@ -60,7 +61,7 @@ pub(crate) async fn set_read_marker_route(
}, },
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services() services
.rooms .rooms
.read_receipt .read_receipt
.private_read_set(&body.room_id, sender_user, count)?; .private_read_set(&body.room_id, sender_user, count)?;
@ -82,7 +83,7 @@ pub(crate) async fn set_read_marker_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts); receipt_content.insert(event.to_owned(), receipts);
services().rooms.read_receipt.readreceipt_update( services.rooms.read_receipt.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
&ruma::events::receipt::ReceiptEvent { &ruma::events::receipt::ReceiptEvent {
@ -99,7 +100,7 @@ pub(crate) async fn set_read_marker_route(
/// ///
/// Sets private read marker and public read receipt EDU. /// Sets private read marker and public read receipt EDU.
pub(crate) async fn create_receipt_route( pub(crate) async fn create_receipt_route(
body: Ruma<create_receipt::v3::Request>, State(services): State<crate::State>, body: Ruma<create_receipt::v3::Request>,
) -> Result<create_receipt::v3::Response> { ) -> Result<create_receipt::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -107,7 +108,7 @@ pub(crate) async fn create_receipt_route(
&body.receipt_type, &body.receipt_type,
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
) { ) {
services() services
.rooms .rooms
.user .user
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)?;
@ -120,7 +121,7 @@ pub(crate) async fn create_receipt_route(
event_id: body.event_id.clone(), event_id: body.event_id.clone(),
}, },
}; };
services().account_data.update( services.account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
@ -142,7 +143,7 @@ pub(crate) async fn create_receipt_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.clone(), receipts); receipt_content.insert(body.event_id.clone(), receipts);
services().rooms.read_receipt.readreceipt_update( services.rooms.read_receipt.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
&ruma::events::receipt::ReceiptEvent { &ruma::events::receipt::ReceiptEvent {
@ -152,7 +153,7 @@ pub(crate) async fn create_receipt_route(
)?; )?;
}, },
create_receipt::v3::ReceiptType::ReadPrivate => { create_receipt::v3::ReceiptType::ReadPrivate => {
let count = services() let count = services
.rooms .rooms
.timeline .timeline
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)?
@ -166,7 +167,7 @@ pub(crate) async fn create_receipt_route(
}, },
PduCount::Normal(c) => c, PduCount::Normal(c) => c,
}; };
services() services
.rooms .rooms
.read_receipt .read_receipt
.private_read_set(&body.room_id, sender_user, count)?; .private_read_set(&body.room_id, sender_user, count)?;

View file

@ -1,23 +1,26 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::redact::redact_event, api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
}; };
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use crate::{service::pdu::PduBuilder, services, Result, Ruma}; use crate::{service::pdu::PduBuilder, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
/// ///
/// Tries to send a redaction event into the room. /// Tries to send a redaction event into the room.
/// ///
/// - TODO: Handle txn id /// - TODO: Handle txn id
pub(crate) async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> { pub(crate) async fn redact_event_route(
State(services): State<crate::State>, body: Ruma<redact_event::v3::Request>,
) -> Result<redact_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body; let body = body.body;
let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let event_id = services() let event_id = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(

View file

@ -1,19 +1,17 @@
use axum::extract::State;
use ruma::api::client::relations::{ use ruma::api::client::relations::{
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
}; };
use crate::{services, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type_and_event_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let res = services() let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
@ -36,14 +34,11 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route(
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_with_rel_type_route(
body: Ruma<get_relating_events_with_rel_type::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events_with_rel_type::v1::Request>,
) -> Result<get_relating_events_with_rel_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let res = services() let res = services.rooms.pdu_metadata.paginate_relations_with_filter(
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,
@ -66,14 +61,11 @@ pub(crate) async fn get_relating_events_with_rel_type_route(
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}`
pub(crate) async fn get_relating_events_route( pub(crate) async fn get_relating_events_route(
body: Ruma<get_relating_events::v1::Request>, State(services): State<crate::State>, body: Ruma<get_relating_events::v1::Request>,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
services() services.rooms.pdu_metadata.paginate_relations_with_filter(
.rooms
.pdu_metadata
.paginate_relations_with_filter(
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_id, &body.event_id,

View file

@ -1,5 +1,6 @@
use std::time::Duration; use std::time::Duration;
use axum::extract::State;
use rand::Rng; use rand::Rng;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, room::report_content}, api::client::{error::ErrorKind, room::report_content},
@ -9,13 +10,18 @@ use ruma::{
use tokio::time::sleep; use tokio::time::sleep;
use tracing::info; use tracing::info;
use crate::{debug_info, service::pdu::PduEvent, services, utils::HtmlEscape, Error, Result, Ruma}; use crate::{
debug_info,
service::{pdu::PduEvent, Services},
utils::HtmlEscape,
Error, Result, Ruma,
};
/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
/// ///
/// Reports an inappropriate event to homeserver admins /// Reports an inappropriate event to homeserver admins
pub(crate) async fn report_event_route( pub(crate) async fn report_event_route(
body: Ruma<report_content::v3::Request>, State(services): State<crate::State>, body: Ruma<report_content::v3::Request>,
) -> Result<report_content::v3::Response> { ) -> Result<report_content::v3::Response> {
// user authentication // user authentication
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -26,18 +32,26 @@ pub(crate) async fn report_event_route(
); );
// check if we know about the reported event ID or if it's invalid // check if we know about the reported event ID or if it's invalid
let Some(pdu) = services().rooms.timeline.get_pdu(&body.event_id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&body.event_id)? else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
"Event ID is not known to us or Event ID is invalid", "Event ID is not known to us or Event ID is invalid",
)); ));
}; };
is_report_valid(&pdu.event_id, &body.room_id, sender_user, &body.reason, body.score, &pdu)?; is_report_valid(
services,
&pdu.event_id,
&body.room_id,
sender_user,
&body.reason,
body.score,
&pdu,
)?;
// send admin room message that we received the report with an @room ping for // send admin room message that we received the report with an @room ping for
// urgency // urgency
services() services
.admin .admin
.send_message(message::RoomMessageEventContent::text_html( .send_message(message::RoomMessageEventContent::text_html(
format!( format!(
@ -79,8 +93,8 @@ pub(crate) async fn report_event_route(
/// check if score is in valid range /// check if score is in valid range
/// check if report reasoning is less than or equal to 750 characters /// check if report reasoning is less than or equal to 750 characters
fn is_report_valid( fn is_report_valid(
event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>, score: Option<ruma::Int>, services: &Services, event_id: &EventId, room_id: &RoomId, sender_user: &UserId, reason: &Option<String>,
pdu: &std::sync::Arc<PduEvent>, score: Option<ruma::Int>, pdu: &std::sync::Arc<PduEvent>,
) -> Result<bool> { ) -> Result<bool> {
debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid"); debug_info!("Checking if report from user {sender_user} for event {event_id} in room {room_id} is valid");
@ -91,7 +105,7 @@ fn is_report_valid(
)); ));
} }
if !services() if !services
.rooms .rooms
.state_cache .state_cache
.room_members(&pdu.room_id) .room_members(&pdu.room_id)

View file

@ -1,5 +1,6 @@
use std::{cmp::max, collections::BTreeMap}; use std::{cmp::max, collections::BTreeMap};
use axum::extract::State;
use conduit::{debug_info, debug_warn}; use conduit::{debug_info, debug_warn};
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -30,8 +31,8 @@ use tracing::{error, info, warn};
use super::invite_helper; use super::invite_helper;
use crate::{ use crate::{
service::{appservice::RegistrationInfo, pdu::PduBuilder}, service::{appservice::RegistrationInfo, pdu::PduBuilder, Services},
services, Error, Result, Ruma, Error, Result, Ruma,
}; };
/// Recommended transferable state events list from the spec /// Recommended transferable state events list from the spec
@ -63,44 +64,46 @@ const TRANSFERABLE_STATE_EVENTS: &[StateEventType; 9] = &[
/// - Send events listed in initial state /// - Send events listed in initial state
/// - Send events implied by `name` and `topic` /// - Send events implied by `name` and `topic`
/// - Send invite events /// - Send invite events
pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<create_room::v3::Response> { pub(crate) async fn create_room_route(
State(services): State<crate::State>, body: Ruma<create_room::v3::Request>,
) -> Result<create_room::v3::Response> {
use create_room::v3::RoomPreset; use create_room::v3::RoomPreset;
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services().globals.allow_room_creation() if !services.globals.allow_room_creation()
&& body.appservice_info.is_none() && body.appservice_info.is_none()
&& !services().users.is_admin(sender_user)? && !services.users.is_admin(sender_user)?
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Room creation has been disabled."));
} }
let room_id: OwnedRoomId = if let Some(custom_room_id) = &body.room_id { let room_id: OwnedRoomId = if let Some(custom_room_id) = &body.room_id {
custom_room_id_check(custom_room_id)? custom_room_id_check(services, custom_room_id)?
} else { } else {
RoomId::new(&services().globals.config.server_name) RoomId::new(&services.globals.config.server_name)
}; };
// check if room ID doesn't already exist instead of erroring on auth check // check if room ID doesn't already exist instead of erroring on auth check
if services().rooms.short.get_shortroomid(&room_id)?.is_some() { if services.rooms.short.get_shortroomid(&room_id)?.is_some() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::RoomInUse, ErrorKind::RoomInUse,
"Room with that custom room ID already exists", "Room with that custom room ID already exists",
)); ));
} }
let _short_id = services().rooms.short.get_or_create_shortroomid(&room_id)?; let _short_id = services.rooms.short.get_or_create_shortroomid(&room_id)?;
let state_lock = services().rooms.state.mutex.lock(&room_id).await; let state_lock = services.rooms.state.mutex.lock(&room_id).await;
let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name { let alias: Option<OwnedRoomAliasId> = if let Some(alias) = &body.room_alias_name {
Some(room_alias_check(alias, &body.appservice_info).await?) Some(room_alias_check(services, alias, &body.appservice_info).await?)
} else { } else {
None None
}; };
let room_version = match body.room_version.clone() { let room_version = match body.room_version.clone() {
Some(room_version) => { Some(room_version) => {
if services() if services
.globals .globals
.supported_room_versions() .supported_room_versions()
.contains(&room_version) .contains(&room_version)
@ -113,7 +116,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
)); ));
} }
}, },
None => services().globals.default_room_version(), None => services.globals.default_room_version(),
}; };
let content = match &body.creation_content { let content = match &body.creation_content {
@ -184,7 +187,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
}; };
// 1. The room create event // 1. The room create event
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -202,7 +205,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
.await?; .await?;
// 2. Let the room creator join // 2. Let the room creator join
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -210,11 +213,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services().users.displayname(sender_user)?, displayname: services.users.displayname(sender_user)?,
avatar_url: services().users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user)?,
is_direct: Some(body.is_direct), is_direct: Some(body.is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: services().users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user)?,
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -249,7 +252,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
let power_levels_content = let power_levels_content =
default_power_levels_content(&body.power_level_content_override, &body.visibility, users)?; default_power_levels_content(&body.power_level_content_override, &body.visibility, users)?;
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -268,7 +271,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
// 4. Canonical room alias // 4. Canonical room alias
if let Some(room_alias_id) = &alias { if let Some(room_alias_id) = &alias {
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -293,7 +296,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
// 5. Events set by preset // 5. Events set by preset
// 5.1 Join Rules // 5.1 Join Rules
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -316,7 +319,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
.await?; .await?;
// 5.2 History Visibility // 5.2 History Visibility
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -335,7 +338,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
.await?; .await?;
// 5.3 Guest Access // 5.3 Guest Access
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -378,11 +381,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
pdu_builder.state_key.get_or_insert_with(String::new); pdu_builder.state_key.get_or_insert_with(String::new);
// Silently skip encryption events if they are not allowed // Silently skip encryption events if they are not allowed
if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services().globals.allow_encryption() { if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services.globals.allow_encryption() {
continue; continue;
} }
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
@ -391,7 +394,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
// 7. Events implied by name and topic // 7. Events implied by name and topic
if let Some(name) = &body.name { if let Some(name) = &body.name {
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -411,7 +414,7 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
} }
if let Some(topic) = &body.topic { if let Some(topic) = &body.topic {
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -435,21 +438,21 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
// 8. Events implied by invite (and TODO: invite_3pid) // 8. Events implied by invite (and TODO: invite_3pid)
drop(state_lock); drop(state_lock);
for user_id in &body.invite { for user_id in &body.invite {
if let Err(e) = invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await { if let Err(e) = invite_helper(services, sender_user, user_id, &room_id, None, body.is_direct).await {
warn!(%e, "Failed to send invite"); warn!(%e, "Failed to send invite");
} }
} }
// Homeserver specific stuff // Homeserver specific stuff
if let Some(alias) = alias { if let Some(alias) = alias {
services() services
.rooms .rooms
.alias .alias
.set_alias(&alias, &room_id, sender_user)?; .set_alias(&alias, &room_id, sender_user)?;
} }
if body.visibility == room::Visibility::Public { if body.visibility == room::Visibility::Public {
services().rooms.directory.set_public(&room_id)?; services.rooms.directory.set_public(&room_id)?;
} }
info!("{sender_user} created a room with room ID {room_id}"); info!("{sender_user} created a room with room ID {room_id}");
@ -464,11 +467,11 @@ pub(crate) async fn create_room_route(body: Ruma<create_room::v3::Request>) -> R
/// - You have to currently be joined to the room (TODO: Respect history /// - You have to currently be joined to the room (TODO: Respect history
/// visibility) /// visibility)
pub(crate) async fn get_room_event_route( pub(crate) async fn get_room_event_route(
body: Ruma<get_room_event::v3::Request>, State(services): State<crate::State>, body: Ruma<get_room_event::v3::Request>,
) -> Result<get_room_event::v3::Response> { ) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services
.rooms .rooms
.timeline .timeline
.get_pdu(&body.event_id)? .get_pdu(&body.event_id)?
@ -477,7 +480,7 @@ pub(crate) async fn get_room_event_route(
Error::BadRequest(ErrorKind::NotFound, "Event not found.") Error::BadRequest(ErrorKind::NotFound, "Event not found.")
})?; })?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &event.room_id, &body.event_id)? .user_can_see_event(sender_user, &event.room_id, &body.event_id)?
@ -502,10 +505,12 @@ pub(crate) async fn get_room_event_route(
/// ///
/// - Only users joined to the room are allowed to call this, or if /// - Only users joined to the room are allowed to call this, or if
/// `history_visibility` is world readable in the room /// `history_visibility` is world readable in the room
pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> { pub(crate) async fn get_room_aliases_route(
State(services): State<crate::State>, body: Ruma<aliases::v3::Request>,
) -> Result<aliases::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)?
@ -517,7 +522,7 @@ pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) ->
} }
Ok(aliases::v3::Response { Ok(aliases::v3::Response {
aliases: services() aliases: services
.rooms .rooms
.alias .alias
.local_aliases_for_room(&body.room_id) .local_aliases_for_room(&body.room_id)
@ -536,10 +541,12 @@ pub(crate) async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) ->
/// - Transfers some state events /// - Transfers some state events
/// - Moves local aliases /// - Moves local aliases
/// - Modifies old room power levels to prevent users from speaking /// - Modifies old room power levels to prevent users from speaking
pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> { pub(crate) async fn upgrade_room_route(
State(services): State<crate::State>, body: Ruma<upgrade_room::v3::Request>,
) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services
.globals .globals
.supported_room_versions() .supported_room_versions()
.contains(&body.new_version) .contains(&body.new_version)
@ -551,19 +558,19 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
} }
// Create a replacement room // Create a replacement room
let replacement_room = RoomId::new(services().globals.server_name()); let replacement_room = RoomId::new(services.globals.server_name());
let _short_id = services() let _short_id = services
.rooms .rooms
.short .short
.get_or_create_shortroomid(&replacement_room)?; .get_or_create_shortroomid(&replacement_room)?;
let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
// Send a m.room.tombstone event to the old room to indicate that it is not // Send a m.room.tombstone event to the old room to indicate that it is not
// intended to be used any further Fail if the sender does not have the required // intended to be used any further Fail if the sender does not have the required
// permissions // permissions
let tombstone_event_id = services() let tombstone_event_id = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -586,11 +593,11 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
// Change lock to replacement room // Change lock to replacement room
drop(state_lock); drop(state_lock);
let state_lock = services().rooms.state.mutex.lock(&replacement_room).await; let state_lock = services.rooms.state.mutex.lock(&replacement_room).await;
// Get the old room creation event // Get the old room creation event
let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
@ -658,7 +665,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"));
} }
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -676,7 +683,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
.await?; .await?;
// Join the new room // Join the new room
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -684,11 +691,11 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: services().users.displayname(sender_user)?, displayname: services.users.displayname(sender_user)?,
avatar_url: services().users.avatar_url(sender_user)?, avatar_url: services.users.avatar_url(sender_user)?,
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: services().users.blurhash(sender_user)?, blurhash: services.users.blurhash(sender_user)?,
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -705,7 +712,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
// Replicate transferable state events to the new room // Replicate transferable state events to the new room
for event_type in TRANSFERABLE_STATE_EVENTS { for event_type in TRANSFERABLE_STATE_EVENTS {
let event_content = match services() let event_content = match services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, event_type, "")? .room_state_get(&body.room_id, event_type, "")?
@ -714,7 +721,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
None => continue, // Skipping missing events. None => continue, // Skipping missing events.
}; };
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -733,13 +740,13 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
} }
// Moves any local aliases to the new room // Moves any local aliases to the new room
for alias in services() for alias in services
.rooms .rooms
.alias .alias
.local_aliases_for_room(&body.room_id) .local_aliases_for_room(&body.room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
{ {
services() services
.rooms .rooms
.alias .alias
.set_alias(&alias, &replacement_room, sender_user)?; .set_alias(&alias, &replacement_room, sender_user)?;
@ -747,7 +754,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
// Get the old room power levels // Get the old room power levels
let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")?
@ -772,7 +779,7 @@ pub(crate) async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) ->
// Modify the power levels in the old room to prevent sending of events and // Modify the power levels in the old room to prevent sending of events and
// inviting new users // inviting new users
services() services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -841,7 +848,7 @@ fn default_power_levels_content(
/// if a room is being created with a room alias, run our checks /// if a room is being created with a room alias, run our checks
async fn room_alias_check( async fn room_alias_check(
room_alias_name: &str, appservice_info: &Option<RegistrationInfo>, services: &Services, room_alias_name: &str, appservice_info: &Option<RegistrationInfo>,
) -> Result<OwnedRoomAliasId> { ) -> Result<OwnedRoomAliasId> {
// Basic checks on the room alias validity // Basic checks on the room alias validity
if room_alias_name.contains(':') { if room_alias_name.contains(':') {
@ -858,7 +865,7 @@ async fn room_alias_check(
} }
// check if room alias is forbidden // check if room alias is forbidden
if services() if services
.globals .globals
.forbidden_alias_names() .forbidden_alias_names()
.is_match(room_alias_name) .is_match(room_alias_name)
@ -866,13 +873,13 @@ async fn room_alias_check(
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden."));
} }
let full_room_alias = RoomAliasId::parse(format!("#{}:{}", room_alias_name, services().globals.config.server_name)) let full_room_alias = RoomAliasId::parse(format!("#{}:{}", room_alias_name, services.globals.config.server_name))
.map_err(|e| { .map_err(|e| {
info!("Failed to parse room alias {room_alias_name}: {e}"); info!("Failed to parse room alias {room_alias_name}: {e}");
Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.")
})?; })?;
if services() if services
.rooms .rooms
.alias .alias
.resolve_local_alias(&full_room_alias)? .resolve_local_alias(&full_room_alias)?
@ -885,7 +892,7 @@ async fn room_alias_check(
if !info.aliases.is_match(full_room_alias.as_str()) { if !info.aliases.is_match(full_room_alias.as_str()) {
return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace.")); return Err(Error::BadRequest(ErrorKind::Exclusive, "Room alias is not in namespace."));
} }
} else if services() } else if services
.appservice .appservice
.is_exclusive_alias(&full_room_alias) .is_exclusive_alias(&full_room_alias)
.await .await
@ -899,9 +906,9 @@ async fn room_alias_check(
} }
/// if a room is being created with a custom room ID, run our checks against it /// if a room is being created with a custom room ID, run our checks against it
fn custom_room_id_check(custom_room_id: &str) -> Result<OwnedRoomId> { fn custom_room_id_check(services: &Services, custom_room_id: &str) -> Result<OwnedRoomId> {
// apply forbidden room alias checks to custom room IDs too // apply forbidden room alias checks to custom room IDs too
if services() if services
.globals .globals
.forbidden_alias_names() .forbidden_alias_names()
.is_match(custom_room_id) .is_match(custom_room_id)
@ -922,7 +929,7 @@ fn custom_room_id_check(custom_room_id: &str) -> Result<OwnedRoomId> {
)); ));
} }
let full_room_id = format!("!{}:{}", custom_room_id, services().globals.config.server_name); let full_room_id = format!("!{}:{}", custom_room_id, services.globals.config.server_name);
debug_info!("Full custom room ID: {full_room_id}"); debug_info!("Full custom room ID: {full_room_id}");

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -14,7 +15,7 @@ use ruma::{
}; };
use tracing::debug; use tracing::debug;
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `POST /_matrix/client/r0/search` /// # `POST /_matrix/client/r0/search`
/// ///
@ -22,7 +23,9 @@ use crate::{services, Error, Result, Ruma};
/// ///
/// - Only works if the user is currently joined to the room (TODO: Respect /// - Only works if the user is currently joined to the room (TODO: Respect
/// history visibility) /// history visibility)
pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> { pub(crate) async fn search_events_route(
State(services): State<crate::State>, body: Ruma<search_events::v3::Request>,
) -> Result<search_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let search_criteria = body.search_categories.room_events.as_ref().unwrap(); let search_criteria = body.search_categories.room_events.as_ref().unwrap();
@ -30,7 +33,7 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>)
let include_state = &search_criteria.include_state; let include_state = &search_criteria.include_state;
let room_ids = filter.rooms.clone().unwrap_or_else(|| { let room_ids = filter.rooms.clone().unwrap_or_else(|| {
services() services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(sender_user) .rooms_joined(sender_user)
@ -50,11 +53,7 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>)
if include_state.is_some_and(|include_state| include_state) { if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids { for room_id in &room_ids {
if !services() if !services.rooms.state_cache.is_joined(sender_user, room_id)? {
.rooms
.state_cache
.is_joined(sender_user, room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -62,12 +61,12 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>)
} }
// check if sender_user can see state events // check if sender_user can see state events
if services() if services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, room_id)? .user_can_see_state_events(sender_user, room_id)?
{ {
let room_state = services() let room_state = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_full(room_id) .room_state_full(room_id)
@ -91,18 +90,14 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>)
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in &room_ids { for room_id in &room_ids {
if !services() if !services.rooms.state_cache.is_joined(sender_user, room_id)? {
.rooms
.state_cache
.is_joined(sender_user, room_id)?
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
if let Some(search) = services() if let Some(search) = services
.rooms .rooms
.search .search
.search_pdus(room_id, &search_criteria.search_term)? .search_pdus(room_id, &search_criteria.search_term)?
@ -135,14 +130,14 @@ pub(crate) async fn search_events_route(body: Ruma<search_events::v3::Request>)
.iter() .iter()
.skip(skip) .skip(skip)
.filter_map(|result| { .filter_map(|result| {
services() services
.rooms .rooms
.timeline .timeline
.get_pdu_from_id(result) .get_pdu_from_id(result)
.ok()? .ok()?
.filter(|pdu| { .filter(|pdu| {
!pdu.is_redacted() !pdu.is_redacted()
&& services() && services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -20,7 +21,7 @@ use serde::Deserialize;
use tracing::{debug, info, warn}; use tracing::{debug, info, warn};
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{services, utils, utils::hash, Error, Result, Ruma}; use crate::{utils, utils::hash, Error, Result, Ruma};
#[derive(Debug, Deserialize)] #[derive(Debug, Deserialize)]
struct Claims { struct Claims {
@ -55,7 +56,9 @@ pub(crate) async fn get_login_types_route(
/// Note: You can use [`GET /// Note: You can use [`GET
/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
/// supported login types. /// supported login types.
pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> { pub(crate) async fn login_route(
State(services): State<crate::State>, body: Ruma<login::v3::Request>,
) -> Result<login::v3::Response> {
// Validate login method // Validate login method
// TODO: Other login methods // TODO: Other login methods
let user_id = match &body.login_info { let user_id = match &body.login_info {
@ -68,7 +71,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
}) => { }) => {
debug!("Got password login type"); debug!("Got password login type");
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) UserId::parse_with_server_name(user_id.to_lowercase(), services.globals.server_name())
} else if let Some(user) = user { } else if let Some(user) = user {
UserId::parse(user) UserId::parse(user)
} else { } else {
@ -77,7 +80,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
} }
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
let hash = services() let hash = services
.users .users
.password_hash(&user_id)? .password_hash(&user_id)?
.ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?; .ok_or(Error::BadRequest(ErrorKind::forbidden(), "Wrong username or password."))?;
@ -96,7 +99,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
token, token,
}) => { }) => {
debug!("Got token login type"); debug!("Got token login type");
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { if let Some(jwt_decoding_key) = services.globals.jwt_decoding_key() {
let token = let token =
jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default()) jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default())
.map_err(|e| { .map_err(|e| {
@ -106,7 +109,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
let username = token.claims.sub.to_lowercase(); let username = token.claims.sub.to_lowercase();
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| { UserId::parse_with_server_name(username, services.globals.server_name()).map_err(|e| {
warn!("Failed to parse username from user logging in: {e}"); warn!("Failed to parse username from user logging in: {e}");
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})? })?
@ -124,7 +127,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
}) => { }) => {
debug!("Got appservice login type"); debug!("Got appservice login type");
let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { let user_id = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
UserId::parse_with_server_name(user_id.to_lowercase(), services().globals.server_name()) UserId::parse_with_server_name(user_id.to_lowercase(), services.globals.server_name())
} else if let Some(user) = user { } else if let Some(user) = user {
UserId::parse(user) UserId::parse(user)
} else { } else {
@ -164,22 +167,22 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| { let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
services() services
.users .users
.all_device_ids(&user_id) .all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id)) .any(|x| x.as_ref().map_or(false, |v| v == device_id))
}); });
if device_exists { if device_exists {
services().users.set_token(&user_id, &device_id, &token)?; services.users.set_token(&user_id, &device_id, &token)?;
} else { } else {
services() services
.users .users
.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; .create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
} }
// send client well-known if specified so the client knows to reconfigure itself // send client well-known if specified so the client knows to reconfigure itself
let client_discovery_info: Option<DiscoveryInfo> = services() let client_discovery_info: Option<DiscoveryInfo> = services
.globals .globals
.well_known_client() .well_known_client()
.as_ref() .as_ref()
@ -197,7 +200,7 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
device_id, device_id,
well_known: client_discovery_info, well_known: client_discovery_info,
expires_in: None, expires_in: None,
home_server: Some(services().globals.server_name().to_owned()), home_server: Some(services.globals.server_name().to_owned()),
refresh_token: None, refresh_token: None,
}) })
} }
@ -211,14 +214,16 @@ pub(crate) async fn login_route(body: Ruma<login::v3::Request>) -> Result<login:
/// last seen ts) /// last seen ts)
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> { pub(crate) async fn logout_route(
State(services): State<crate::State>, body: Ruma<logout::v3::Request>,
) -> Result<logout::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
services().users.remove_device(sender_user, sender_device)?; services.users.remove_device(sender_user, sender_device)?;
// send device list update for user after logout // send device list update for user after logout
services().users.mark_device_key_update(sender_user)?; services.users.mark_device_key_update(sender_user)?;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())
} }
@ -236,15 +241,17 @@ pub(crate) async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logo
/// Note: This is equivalent to calling [`GET /// Note: This is equivalent to calling [`GET
/// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this /// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this
/// user. /// user.
pub(crate) async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> { pub(crate) async fn logout_all_route(
State(services): State<crate::State>, body: Ruma<logout_all::v3::Request>,
) -> Result<logout_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for device_id in services().users.all_device_ids(sender_user).flatten() { for device_id in services.users.all_device_ids(sender_user).flatten() {
services().users.remove_device(sender_user, &device_id)?; services.users.remove_device(sender_user, &device_id)?;
} }
// send device list update for user after logout // send device list update for user after logout
services().users.mark_device_key_update(sender_user)?; services.users.mark_device_key_update(sender_user)?;
Ok(logout_all::v3::Response::new()) Ok(logout_all::v3::Response::new())
} }

View file

@ -1,17 +1,20 @@
use std::str::FromStr; use std::str::FromStr;
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, space::get_hierarchy}, api::client::{error::ErrorKind, space::get_hierarchy},
UInt, UInt,
}; };
use crate::{service::rooms::spaces::PaginationToken, services, Error, Result, Ruma}; use crate::{service::rooms::spaces::PaginationToken, Error, Result, Ruma};
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`
/// ///
/// Paginates over the space tree in a depth-first manner to locate child rooms /// Paginates over the space tree in a depth-first manner to locate child rooms
/// of a given space. /// of a given space.
pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { pub(crate) async fn get_hierarchy_route(
State(services): State<crate::State>, body: Ruma<get_hierarchy::v1::Request>,
) -> Result<get_hierarchy::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = body let limit = body
@ -39,7 +42,7 @@ pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>)
} }
} }
services() services
.rooms .rooms
.spaces .spaces
.get_client_hierarchy( .get_client_hierarchy(

View file

@ -1,5 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State;
use conduit::{debug_info, error}; use conduit::{debug_info, error};
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -19,8 +20,8 @@ use ruma::{
}; };
use crate::{ use crate::{
service::{pdu::PduBuilder, server_is_ours}, service::{pdu::PduBuilder, server_is_ours, Services},
services, Error, Result, Ruma, RumaResponse, Error, Result, Ruma, RumaResponse,
}; };
/// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}` /// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}`
@ -32,12 +33,13 @@ use crate::{
/// allowed /// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect /// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_key_route( pub(crate) async fn send_state_event_for_key_route(
body: Ruma<send_state_event::v3::Request>, State(services): State<crate::State>, body: Ruma<send_state_event::v3::Request>,
) -> Result<send_state_event::v3::Response> { ) -> Result<send_state_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(send_state_event::v3::Response { Ok(send_state_event::v3::Response {
event_id: send_state_event_for_key_helper( event_id: send_state_event_for_key_helper(
services,
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_type, &body.event_type,
@ -58,9 +60,11 @@ pub(crate) async fn send_state_event_for_key_route(
/// allowed /// allowed
/// - If event is new `canonical_alias`: Rejects if alias is incorrect /// - If event is new `canonical_alias`: Rejects if alias is incorrect
pub(crate) async fn send_state_event_for_empty_key_route( pub(crate) async fn send_state_event_for_empty_key_route(
body: Ruma<send_state_event::v3::Request>, State(services): State<crate::State>, body: Ruma<send_state_event::v3::Request>,
) -> Result<RumaResponse<send_state_event::v3::Response>> { ) -> Result<RumaResponse<send_state_event::v3::Response>> {
send_state_event_for_key_route(body).await.map(RumaResponse) send_state_event_for_key_route(State(services), body)
.await
.map(RumaResponse)
} }
/// # `GET /_matrix/client/v3/rooms/{roomid}/state` /// # `GET /_matrix/client/v3/rooms/{roomid}/state`
@ -70,11 +74,11 @@ pub(crate) async fn send_state_event_for_empty_key_route(
/// - If not joined: Only works if current room history visibility is world /// - If not joined: Only works if current room history visibility is world
/// readable /// readable
pub(crate) async fn get_state_events_route( pub(crate) async fn get_state_events_route(
body: Ruma<get_state_events::v3::Request>, State(services): State<crate::State>, body: Ruma<get_state_events::v3::Request>,
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)?
@ -86,7 +90,7 @@ pub(crate) async fn get_state_events_route(
} }
Ok(get_state_events::v3::Response { Ok(get_state_events::v3::Response {
room_state: services() room_state: services
.rooms .rooms
.state_accessor .state_accessor
.room_state_full(&body.room_id) .room_state_full(&body.room_id)
@ -106,11 +110,11 @@ pub(crate) async fn get_state_events_route(
/// - If not joined: Only works if current room history visibility is world /// - If not joined: Only works if current room history visibility is world
/// readable /// readable
pub(crate) async fn get_state_events_for_key_route( pub(crate) async fn get_state_events_for_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, State(services): State<crate::State>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_state_events(sender_user, &body.room_id)? .user_can_see_state_events(sender_user, &body.room_id)?
@ -121,7 +125,7 @@ pub(crate) async fn get_state_events_for_key_route(
)); ));
} }
let event = services() let event = services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&body.room_id, &body.event_type, &body.state_key)? .room_state_get(&body.room_id, &body.event_type, &body.state_key)?
@ -161,17 +165,20 @@ pub(crate) async fn get_state_events_for_key_route(
/// - If not joined: Only works if current room history visibility is world /// - If not joined: Only works if current room history visibility is world
/// readable /// readable
pub(crate) async fn get_state_events_for_empty_key_route( pub(crate) async fn get_state_events_for_empty_key_route(
body: Ruma<get_state_events_for_key::v3::Request>, State(services): State<crate::State>, body: Ruma<get_state_events_for_key::v3::Request>,
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
get_state_events_for_key_route(body).await.map(RumaResponse) get_state_events_for_key_route(State(services), body)
.await
.map(RumaResponse)
} }
async fn send_state_event_for_key_helper( async fn send_state_event_for_key_helper(
sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String, services: &Services, sender: &UserId, room_id: &RoomId, event_type: &StateEventType,
json: &Raw<AnyStateEventContent>, state_key: String,
) -> Result<Arc<EventId>> { ) -> Result<Arc<EventId>> {
allowed_to_send_state_event(room_id, event_type, json).await?; allowed_to_send_state_event(services, room_id, event_type, json).await?;
let state_lock = services().rooms.state.mutex.lock(room_id).await; let state_lock = services.rooms.state.mutex.lock(room_id).await;
let event_id = services() let event_id = services
.rooms .rooms
.timeline .timeline
.build_and_append_pdu( .build_and_append_pdu(
@ -192,12 +199,12 @@ async fn send_state_event_for_key_helper(
} }
async fn allowed_to_send_state_event( async fn allowed_to_send_state_event(
room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, services: &Services, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>,
) -> Result<()> { ) -> Result<()> {
match event_type { match event_type {
// Forbid m.room.encryption if encryption is disabled // Forbid m.room.encryption if encryption is disabled
StateEventType::RoomEncryption => { StateEventType::RoomEncryption => {
if !services().globals.allow_encryption() { if !services.globals.allow_encryption() {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Encryption has been disabled"));
} }
}, },
@ -244,7 +251,7 @@ async fn allowed_to_send_state_event(
for alias in aliases { for alias in aliases {
if !server_is_ours(alias.server_name()) if !server_is_ours(alias.server_name())
|| services() || services
.rooms .rooms
.alias .alias
.resolve_local_alias(&alias)? .resolve_local_alias(&alias)?

View file

@ -5,6 +5,7 @@ use std::{
time::Duration, time::Duration,
}; };
use axum::extract::State;
use conduit::{ use conduit::{
error, error,
utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, utils::math::{ruma_from_u64, ruma_from_usize, usize_from_ruma, usize_from_u64_truncated},
@ -17,7 +18,7 @@ use ruma::{
self, self,
v3::{ v3::{
Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence,
RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, RoomAccountData, RoomSummary, Rooms, State as RoomState, Timeline, ToDevice,
}, },
v4::SlidingOp, v4::SlidingOp,
DeviceLists, UnreadNotificationsCount, DeviceLists, UnreadNotificationsCount,
@ -34,7 +35,10 @@ use ruma::{
}; };
use tracing::{Instrument as _, Span}; use tracing::{Instrument as _, Span};
use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, Ruma, RumaResponse}; use crate::{
service::{pdu::EventHash, Services},
utils, Error, PduEvent, Result, Ruma, RumaResponse,
};
/// # `GET /_matrix/client/r0/sync` /// # `GET /_matrix/client/r0/sync`
/// ///
@ -72,23 +76,23 @@ use crate::{service::pdu::EventHash, services, utils, Error, PduEvent, Result, R
/// - If the user left after `since`: `prev_batch` token, empty state (TODO: /// - If the user left after `since`: `prev_batch` token, empty state (TODO:
/// subset of the state at the point of the leave) /// subset of the state at the point of the leave)
pub(crate) async fn sync_events_route( pub(crate) async fn sync_events_route(
body: Ruma<sync_events::v3::Request>, State(services): State<crate::State>, body: Ruma<sync_events::v3::Request>,
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
let sender_user = body.sender_user.expect("user is authenticated"); let sender_user = body.sender_user.expect("user is authenticated");
let sender_device = body.sender_device.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated");
let body = body.body; let body = body.body;
// Presence update // Presence update
if services().globals.allow_local_presence() { if services.globals.allow_local_presence() {
services() services
.presence .presence
.ping_presence(&sender_user, &body.set_presence)?; .ping_presence(&sender_user, &body.set_presence)?;
} }
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
let watcher = services().globals.watch(&sender_user, &sender_device); let watcher = services.globals.watch(&sender_user, &sender_device);
let next_batch = services().globals.current_count()?; let next_batch = services.globals.current_count()?;
let next_batchcount = PduCount::Normal(next_batch); let next_batchcount = PduCount::Normal(next_batch);
let next_batch_string = next_batch.to_string(); let next_batch_string = next_batch.to_string();
@ -96,7 +100,7 @@ pub(crate) async fn sync_events_route(
let filter = match body.filter { let filter = match body.filter {
None => FilterDefinition::default(), None => FilterDefinition::default(),
Some(Filter::FilterDefinition(filter)) => filter, Some(Filter::FilterDefinition(filter)) => filter,
Some(Filter::FilterId(filter_id)) => services() Some(Filter::FilterId(filter_id)) => services
.users .users
.get_filter(&sender_user, &filter_id)? .get_filter(&sender_user, &filter_id)?
.unwrap_or_default(), .unwrap_or_default(),
@ -126,28 +130,29 @@ pub(crate) async fn sync_events_route(
// Look for device list updates of this account // Look for device list updates of this account
device_list_updates.extend( device_list_updates.extend(
services() services
.users .users
.keys_changed(sender_user.as_ref(), since, None) .keys_changed(sender_user.as_ref(), since, None)
.filter_map(Result::ok), .filter_map(Result::ok),
); );
if services().globals.allow_local_presence() { if services.globals.allow_local_presence() {
process_presence_updates(&mut presence_updates, since, &sender_user).await?; process_presence_updates(services, &mut presence_updates, since, &sender_user).await?;
} }
let all_joined_rooms = services() let all_joined_rooms = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&sender_user) .rooms_joined(&sender_user)
.collect::<Vec<_>>(); .collect::<Vec<_>>();
// Coalesce database writes for the remainder of this scope. // Coalesce database writes for the remainder of this scope.
let _cork = services().db.cork_and_flush(); let _cork = services.db.cork_and_flush();
for room_id in all_joined_rooms { for room_id in all_joined_rooms {
let room_id = room_id?; let room_id = room_id?;
if let Ok(joined_room) = load_joined_room( if let Ok(joined_room) = load_joined_room(
services,
&sender_user, &sender_user,
&sender_device, &sender_device,
&room_id, &room_id,
@ -170,13 +175,14 @@ pub(crate) async fn sync_events_route(
} }
let mut left_rooms = BTreeMap::new(); let mut left_rooms = BTreeMap::new();
let all_left_rooms: Vec<_> = services() let all_left_rooms: Vec<_> = services
.rooms .rooms
.state_cache .state_cache
.rooms_left(&sender_user) .rooms_left(&sender_user)
.collect(); .collect();
for result in all_left_rooms { for result in all_left_rooms {
handle_left_room( handle_left_room(
services,
since, since,
&result?.0, &result?.0,
&sender_user, &sender_user,
@ -190,7 +196,7 @@ pub(crate) async fn sync_events_route(
} }
let mut invited_rooms = BTreeMap::new(); let mut invited_rooms = BTreeMap::new();
let all_invited_rooms: Vec<_> = services() let all_invited_rooms: Vec<_> = services
.rooms .rooms
.state_cache .state_cache
.rooms_invited(&sender_user) .rooms_invited(&sender_user)
@ -199,10 +205,10 @@ pub(crate) async fn sync_events_route(
let (room_id, invite_state_events) = result?; let (room_id, invite_state_events) = result?;
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let insert_lock = services().rooms.timeline.mutex_insert.lock(&room_id).await; let insert_lock = services.rooms.timeline.mutex_insert.lock(&room_id).await;
drop(insert_lock); drop(insert_lock);
let invite_count = services() let invite_count = services
.rooms .rooms
.state_cache .state_cache
.get_invite_count(&room_id, &sender_user)?; .get_invite_count(&room_id, &sender_user)?;
@ -223,14 +229,14 @@ pub(crate) async fn sync_events_route(
} }
for user_id in left_encrypted_users { for user_id in left_encrypted_users {
let dont_share_encrypted_room = services() let dont_share_encrypted_room = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.filter_map(Result::ok) .filter_map(Result::ok)
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
@ -247,7 +253,7 @@ pub(crate) async fn sync_events_route(
} }
// Remove all to-device events the device received *last time* // Remove all to-device events the device received *last time*
services() services
.users .users
.remove_to_device_events(&sender_user, &sender_device, since)?; .remove_to_device_events(&sender_user, &sender_device, since)?;
@ -266,7 +272,7 @@ pub(crate) async fn sync_events_route(
.collect(), .collect(),
}, },
account_data: GlobalAccountData { account_data: GlobalAccountData {
events: services() events: services
.account_data .account_data
.changes_since(None, &sender_user, since)? .changes_since(None, &sender_user, since)?
.into_iter() .into_iter()
@ -281,11 +287,11 @@ pub(crate) async fn sync_events_route(
changed: device_list_updates.into_iter().collect(), changed: device_list_updates.into_iter().collect(),
left: device_list_left.into_iter().collect(), left: device_list_left.into_iter().collect(),
}, },
device_one_time_keys_count: services() device_one_time_keys_count: services
.users .users
.count_one_time_keys(&sender_user, &sender_device)?, .count_one_time_keys(&sender_user, &sender_device)?,
to_device: ToDevice { to_device: ToDevice {
events: services() events: services
.users .users
.get_to_device_events(&sender_user, &sender_device)?, .get_to_device_events(&sender_user, &sender_device)?,
}, },
@ -311,16 +317,18 @@ pub(crate) async fn sync_events_route(
Ok(response) Ok(response)
} }
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(user_id = %sender_user, room_id = %room_id), name = "left_room")] #[tracing::instrument(skip_all, fields(user_id = %sender_user, room_id = %room_id), name = "left_room")]
async fn handle_left_room( async fn handle_left_room(
since: u64, room_id: &RoomId, sender_user: &UserId, left_rooms: &mut BTreeMap<ruma::OwnedRoomId, LeftRoom>, services: &Services, since: u64, room_id: &RoomId, sender_user: &UserId,
next_batch_string: &str, full_state: bool, lazy_load_enabled: bool, left_rooms: &mut BTreeMap<ruma::OwnedRoomId, LeftRoom>, next_batch_string: &str, full_state: bool,
lazy_load_enabled: bool,
) -> Result<()> { ) -> Result<()> {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await;
drop(insert_lock); drop(insert_lock);
let left_count = services() let left_count = services
.rooms .rooms
.state_cache .state_cache
.get_left_count(room_id, sender_user)?; .get_left_count(room_id, sender_user)?;
@ -330,11 +338,11 @@ async fn handle_left_room(
return Ok(()); return Ok(());
} }
if !services().rooms.metadata.exists(room_id)? { if !services.rooms.metadata.exists(room_id)? {
// This is just a rejected invite, not a room we know // This is just a rejected invite, not a room we know
// Insert a leave event anyways // Insert a leave event anyways
let event = PduEvent { let event = PduEvent {
event_id: EventId::new(services().globals.server_name()).into(), event_id: EventId::new(services.globals.server_name()).into(),
sender: sender_user.to_owned(), sender: sender_user.to_owned(),
origin: None, origin: None,
origin_server_ts: utils::millis_since_unix_epoch() origin_server_ts: utils::millis_since_unix_epoch()
@ -367,7 +375,7 @@ async fn handle_left_room(
prev_batch: Some(next_batch_string.to_owned()), prev_batch: Some(next_batch_string.to_owned()),
events: Vec::new(), events: Vec::new(),
}, },
state: State { state: RoomState {
events: vec![event.to_sync_state_event()], events: vec![event.to_sync_state_event()],
}, },
}, },
@ -377,27 +385,27 @@ async fn handle_left_room(
let mut left_state_events = Vec::new(); let mut left_state_events = Vec::new();
let since_shortstatehash = services() let since_shortstatehash = services
.rooms .rooms
.user .user
.get_token_shortstatehash(room_id, since)?; .get_token_shortstatehash(room_id, since)?;
let since_state_ids = match since_shortstatehash { let since_state_ids = match since_shortstatehash {
Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, Some(s) => services.rooms.state_accessor.state_full_ids(s).await?,
None => HashMap::new(), None => HashMap::new(),
}; };
let Some(left_event_id) = services().rooms.state_accessor.room_state_get_id( let Some(left_event_id) =
room_id, services
&StateEventType::RoomMember, .rooms
sender_user.as_str(), .state_accessor
)? .room_state_get_id(room_id, &StateEventType::RoomMember, sender_user.as_str())?
else { else {
error!("Left room but no left state event"); error!("Left room but no left state event");
return Ok(()); return Ok(());
}; };
let Some(left_shortstatehash) = services() let Some(left_shortstatehash) = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash(&left_event_id)? .pdu_shortstatehash(&left_event_id)?
@ -406,13 +414,13 @@ async fn handle_left_room(
return Ok(()); return Ok(());
}; };
let mut left_state_ids = services() let mut left_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(left_shortstatehash) .state_full_ids(left_shortstatehash)
.await?; .await?;
let leave_shortstatekey = services() let leave_shortstatekey = services
.rooms .rooms
.short .short
.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?;
@ -422,7 +430,7 @@ async fn handle_left_room(
let mut i: u8 = 0; let mut i: u8 = 0;
for (key, id) in left_state_ids { for (key, id) in left_state_ids {
if full_state || since_state_ids.get(&key) != Some(&id) { if full_state || since_state_ids.get(&key) != Some(&id) {
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key)?;
if !lazy_load_enabled if !lazy_load_enabled
|| event_type != StateEventType::RoomMember || event_type != StateEventType::RoomMember
@ -430,7 +438,7 @@ async fn handle_left_room(
// TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565
|| (cfg!(feature = "element_hacks") && *sender_user == state_key) || (cfg!(feature = "element_hacks") && *sender_user == state_key)
{ {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
@ -456,7 +464,7 @@ async fn handle_left_room(
prev_batch: Some(next_batch_string.to_owned()), prev_batch: Some(next_batch_string.to_owned()),
events: Vec::new(), events: Vec::new(),
}, },
state: State { state: RoomState {
events: left_state_events, events: left_state_events,
}, },
}, },
@ -465,13 +473,13 @@ async fn handle_left_room(
} }
async fn process_presence_updates( async fn process_presence_updates(
presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId, services: &Services, presence_updates: &mut HashMap<OwnedUserId, PresenceEvent>, since: u64, syncing_user: &UserId,
) -> Result<()> { ) -> Result<()> {
use crate::service::presence::Presence; use crate::service::presence::Presence;
// Take presence updates // Take presence updates
for (user_id, _, presence_bytes) in services().presence.presence_since(since) { for (user_id, _, presence_bytes) in services.presence.presence_since(since) {
if !services() if !services
.rooms .rooms
.state_cache .state_cache
.user_sees_user(syncing_user, &user_id)? .user_sees_user(syncing_user, &user_id)?
@ -513,19 +521,20 @@ async fn process_presence_updates(
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
async fn load_joined_room( async fn load_joined_room(
sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, sincecount: PduCount, services: &Services, sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64,
next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, lazy_load_send_redundant: bool, sincecount: PduCount, next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool,
full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>, left_encrypted_users: &mut HashSet<OwnedUserId>, lazy_load_send_redundant: bool, full_state: bool, device_list_updates: &mut HashSet<OwnedUserId>,
left_encrypted_users: &mut HashSet<OwnedUserId>,
) -> Result<JoinedRoom> { ) -> Result<JoinedRoom> {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
// This will make sure the we have all events until next_batch // This will make sure the we have all events until next_batch
let insert_lock = services().rooms.timeline.mutex_insert.lock(room_id).await; let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await;
drop(insert_lock); drop(insert_lock);
let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10)?;
let send_notification_counts = !timeline_pdus.is_empty() let send_notification_counts = !timeline_pdus.is_empty()
|| services() || services
.rooms .rooms
.user .user
.last_notification_read(sender_user, room_id)? .last_notification_read(sender_user, room_id)?
@ -536,7 +545,7 @@ async fn load_joined_room(
timeline_users.insert(event.sender.as_str().to_owned()); timeline_users.insert(event.sender.as_str().to_owned());
} }
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount) .lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount)
@ -544,11 +553,11 @@ async fn load_joined_room(
// Database queries: // Database queries:
let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else {
return Err!(Database(error!("Room {room_id} has no state"))); return Err!(Database(error!("Room {room_id} has no state")));
}; };
let since_shortstatehash = services() let since_shortstatehash = services
.rooms .rooms
.user .user
.get_token_shortstatehash(room_id, since)?; .get_token_shortstatehash(room_id, since)?;
@ -560,12 +569,12 @@ async fn load_joined_room(
} else { } else {
// Calculates joined_member_count, invited_member_count and heroes // Calculates joined_member_count, invited_member_count and heroes
let calculate_counts = || { let calculate_counts = || {
let joined_member_count = services() let joined_member_count = services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(room_id)? .room_joined_count(room_id)?
.unwrap_or(0); .unwrap_or(0);
let invited_member_count = services() let invited_member_count = services
.rooms .rooms
.state_cache .state_cache
.room_invited_count(room_id)? .room_invited_count(room_id)?
@ -578,7 +587,7 @@ async fn load_joined_room(
// Go through all PDUs and for each member event, check if the user is still // Go through all PDUs and for each member event, check if the user is still
// joined or invited until we have 5 or we reach the end // joined or invited until we have 5 or we reach the end
for hero in services() for hero in services
.rooms .rooms
.timeline .timeline
.all_pdus(sender_user, room_id)? .all_pdus(sender_user, room_id)?
@ -594,8 +603,8 @@ async fn load_joined_room(
// The membership was and still is invite or join // The membership was and still is invite or join
if matches!(content.membership, MembershipState::Join | MembershipState::Invite) if matches!(content.membership, MembershipState::Join | MembershipState::Invite)
&& (services().rooms.state_cache.is_joined(&user_id, room_id)? && (services.rooms.state_cache.is_joined(&user_id, room_id)?
|| services().rooms.state_cache.is_invited(&user_id, room_id)?) || services.rooms.state_cache.is_invited(&user_id, room_id)?)
{ {
Ok::<_, Error>(Some(user_id)) Ok::<_, Error>(Some(user_id))
} else { } else {
@ -622,7 +631,7 @@ async fn load_joined_room(
let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash
.and_then(|shortstatehash| { .and_then(|shortstatehash| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str())
@ -643,7 +652,7 @@ async fn load_joined_room(
let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = services() let current_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)
@ -654,13 +663,13 @@ async fn load_joined_room(
let mut i: u8 = 0; let mut i: u8 = 0;
for (shortstatekey, id) in current_state_ids { for (shortstatekey, id) in current_state_ids {
let (event_type, state_key) = services() let (event_type, state_key) = services
.rooms .rooms
.short .short
.get_statekey_from_short(shortstatekey)?; .get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
@ -676,7 +685,7 @@ async fn load_joined_room(
// TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565
|| (cfg!(feature = "element_hacks") && *sender_user == state_key) || (cfg!(feature = "element_hacks") && *sender_user == state_key)
{ {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
@ -695,14 +704,14 @@ async fn load_joined_room(
} }
// Reset lazy loading because this is an initial sync // Reset lazy loading because this is an initial sync
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_reset(sender_user, sender_device, room_id)?; .lazy_load_reset(sender_user, sender_device, room_id)?;
// The state_events above should contain all timeline_users, let's mark them as // The state_events above should contain all timeline_users, let's mark them as
// lazy loaded. // lazy loaded.
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount)
@ -716,12 +725,12 @@ async fn load_joined_room(
let mut delta_state_events = Vec::new(); let mut delta_state_events = Vec::new();
if since_shortstatehash != current_shortstatehash { if since_shortstatehash != current_shortstatehash {
let current_state_ids = services() let current_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)
.await?; .await?;
let since_state_ids = services() let since_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(since_shortstatehash) .state_full_ids(since_shortstatehash)
@ -729,7 +738,7 @@ async fn load_joined_room(
for (key, id) in current_state_ids { for (key, id) in current_state_ids {
if full_state || since_state_ids.get(&key) != Some(&id) { if full_state || since_state_ids.get(&key) != Some(&id) {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
@ -740,13 +749,13 @@ async fn load_joined_room(
} }
} }
let encrypted_room = services() let encrypted_room = services
.rooms .rooms
.state_accessor .state_accessor
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")?
.is_some(); .is_some();
let since_encryption = services().rooms.state_accessor.state_get( let since_encryption = services.rooms.state_accessor.state_get(
since_shortstatehash, since_shortstatehash,
&StateEventType::RoomEncryption, &StateEventType::RoomEncryption,
"", "",
@ -781,7 +790,7 @@ async fn load_joined_room(
match new_membership { match new_membership {
MembershipState::Join => { MembershipState::Join => {
// A new user joined an encrypted room // A new user joined an encrypted room
if !share_encrypted_room(sender_user, &user_id, room_id)? { if !share_encrypted_room(services, sender_user, &user_id, room_id)? {
device_list_updates.insert(user_id); device_list_updates.insert(user_id);
} }
}, },
@ -798,7 +807,7 @@ async fn load_joined_room(
if joined_since_last_sync && encrypted_room || new_encrypted_room { if joined_since_last_sync && encrypted_room || new_encrypted_room {
// If the user is in a new encrypted room, give them all joined users // If the user is in a new encrypted room, give them all joined users
device_list_updates.extend( device_list_updates.extend(
services() services
.rooms .rooms
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
@ -810,7 +819,7 @@ async fn load_joined_room(
.filter(|user_id| { .filter(|user_id| {
// Only send keys if the sender doesn't share an encrypted room with the target // Only send keys if the sender doesn't share an encrypted room with the target
// already // already
!share_encrypted_room(sender_user, user_id, room_id).unwrap_or(false) !share_encrypted_room(services, sender_user, user_id, room_id).unwrap_or(false)
}), }),
); );
} }
@ -848,14 +857,14 @@ async fn load_joined_room(
continue; continue;
} }
if !services().rooms.lazy_loading.lazy_load_was_sent_before( if !services.rooms.lazy_loading.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
room_id, room_id,
&event.sender, &event.sender,
)? || lazy_load_send_redundant )? || lazy_load_send_redundant
{ {
if let Some(member_event) = services().rooms.state_accessor.room_state_get( if let Some(member_event) = services.rooms.state_accessor.room_state_get(
room_id, room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
event.sender.as_str(), event.sender.as_str(),
@ -866,7 +875,7 @@ async fn load_joined_room(
} }
} }
services() services
.rooms .rooms
.lazy_loading .lazy_loading
.lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount) .lazy_load_mark_sent(sender_user, sender_device, room_id, lazy_loaded, next_batchcount)
@ -884,7 +893,7 @@ async fn load_joined_room(
// Look for device list updates in this room // Look for device list updates in this room
device_list_updates.extend( device_list_updates.extend(
services() services
.users .users
.keys_changed(room_id.as_ref(), since, None) .keys_changed(room_id.as_ref(), since, None)
.filter_map(Result::ok), .filter_map(Result::ok),
@ -892,7 +901,7 @@ async fn load_joined_room(
let notification_count = if send_notification_counts { let notification_count = if send_notification_counts {
Some( Some(
services() services
.rooms .rooms
.user .user
.notification_count(sender_user, room_id)? .notification_count(sender_user, room_id)?
@ -905,7 +914,7 @@ async fn load_joined_room(
let highlight_count = if send_notification_counts { let highlight_count = if send_notification_counts {
Some( Some(
services() services
.rooms .rooms
.user .user
.highlight_count(sender_user, room_id)? .highlight_count(sender_user, room_id)?
@ -933,7 +942,7 @@ async fn load_joined_room(
.map(|(_, pdu)| pdu.to_sync_room_event()) .map(|(_, pdu)| pdu.to_sync_room_event())
.collect(); .collect();
let mut edus: Vec<_> = services() let mut edus: Vec<_> = services
.rooms .rooms
.read_receipt .read_receipt
.readreceipts_since(room_id, since) .readreceipts_since(room_id, since)
@ -941,10 +950,10 @@ async fn load_joined_room(
.map(|(_, _, v)| v) .map(|(_, _, v)| v)
.collect(); .collect();
if services().rooms.typing.last_typing_update(room_id).await? > since { if services.rooms.typing.last_typing_update(room_id).await? > since {
edus.push( edus.push(
serde_json::from_str( serde_json::from_str(
&serde_json::to_string(&services().rooms.typing.typings_all(room_id).await?) &serde_json::to_string(&services.rooms.typing.typings_all(room_id).await?)
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
) )
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
@ -953,14 +962,14 @@ async fn load_joined_room(
// Save the state after this sync so we can send the correct state diff next // Save the state after this sync so we can send the correct state diff next
// sync // sync
services() services
.rooms .rooms
.user .user
.associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; .associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?;
Ok(JoinedRoom { Ok(JoinedRoom {
account_data: RoomAccountData { account_data: RoomAccountData {
events: services() events: services
.account_data .account_data
.changes_since(Some(room_id), sender_user, since)? .changes_since(Some(room_id), sender_user, since)?
.into_iter() .into_iter()
@ -985,7 +994,7 @@ async fn load_joined_room(
prev_batch, prev_batch,
events: room_events, events: room_events,
}, },
state: State { state: RoomState {
events: state_events events: state_events
.iter() .iter()
.map(|pdu| pdu.to_sync_state_event()) .map(|pdu| pdu.to_sync_state_event())
@ -999,16 +1008,16 @@ async fn load_joined_room(
} }
fn load_timeline( fn load_timeline(
sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64,
) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> {
let timeline_pdus; let timeline_pdus;
let limited = if services() let limited = if services
.rooms .rooms
.timeline .timeline
.last_timeline_count(sender_user, room_id)? .last_timeline_count(sender_user, room_id)?
> roomsincecount > roomsincecount
{ {
let mut non_timeline_pdus = services() let mut non_timeline_pdus = services
.rooms .rooms
.timeline .timeline
.pdus_until(sender_user, room_id, PduCount::max())? .pdus_until(sender_user, room_id, PduCount::max())?
@ -1040,8 +1049,10 @@ fn load_timeline(
Ok((timeline_pdus, limited)) Ok((timeline_pdus, limited))
} }
fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId) -> Result<bool> { fn share_encrypted_room(
Ok(services() services: &Services, sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId,
) -> Result<bool> {
Ok(services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])?
@ -1049,7 +1060,7 @@ fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &Ro
.filter(|room_id| room_id != ignore_room) .filter(|room_id| room_id != ignore_room)
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
@ -1064,15 +1075,15 @@ fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &Ro
/// ///
/// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`) /// Sliding Sync endpoint (future endpoint: `/_matrix/client/v4/sync`)
pub(crate) async fn sync_events_v4_route( pub(crate) async fn sync_events_v4_route(
body: Ruma<sync_events::v4::Request>, State(services): State<crate::State>, body: Ruma<sync_events::v4::Request>,
) -> Result<sync_events::v4::Response, RumaResponse<UiaaResponse>> { ) -> Result<sync_events::v4::Response, RumaResponse<UiaaResponse>> {
let sender_user = body.sender_user.expect("user is authenticated"); let sender_user = body.sender_user.expect("user is authenticated");
let sender_device = body.sender_device.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated");
let mut body = body.body; let mut body = body.body;
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
let watcher = services().globals.watch(&sender_user, &sender_device); let watcher = services.globals.watch(&sender_user, &sender_device);
let next_batch = services().globals.next_count()?; let next_batch = services.globals.next_count()?;
let globalsince = body let globalsince = body
.pos .pos
@ -1082,21 +1093,19 @@ pub(crate) async fn sync_events_v4_route(
if globalsince == 0 { if globalsince == 0 {
if let Some(conn_id) = &body.conn_id { if let Some(conn_id) = &body.conn_id {
services().users.forget_sync_request_connection( services
sender_user.clone(), .users
sender_device.clone(), .forget_sync_request_connection(sender_user.clone(), sender_device.clone(), conn_id.clone());
conn_id.clone(),
);
} }
} }
// Get sticky parameters from cache // Get sticky parameters from cache
let known_rooms = let known_rooms =
services() services
.users .users
.update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body); .update_sync_request_with_cache(sender_user.clone(), sender_device.clone(), &mut body);
let all_joined_rooms = services() let all_joined_rooms = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&sender_user) .rooms_joined(&sender_user)
@ -1104,7 +1113,7 @@ pub(crate) async fn sync_events_v4_route(
.collect::<Vec<_>>(); .collect::<Vec<_>>();
if body.extensions.to_device.enabled.unwrap_or(false) { if body.extensions.to_device.enabled.unwrap_or(false) {
services() services
.users .users
.remove_to_device_events(&sender_user, &sender_device, globalsince)?; .remove_to_device_events(&sender_user, &sender_device, globalsince)?;
} }
@ -1116,26 +1125,26 @@ pub(crate) async fn sync_events_v4_route(
if body.extensions.e2ee.enabled.unwrap_or(false) { if body.extensions.e2ee.enabled.unwrap_or(false) {
// Look for device list updates of this account // Look for device list updates of this account
device_list_changes.extend( device_list_changes.extend(
services() services
.users .users
.keys_changed(sender_user.as_ref(), globalsince, None) .keys_changed(sender_user.as_ref(), globalsince, None)
.filter_map(Result::ok), .filter_map(Result::ok),
); );
for room_id in &all_joined_rooms { for room_id in &all_joined_rooms {
let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? else { let Some(current_shortstatehash) = services.rooms.state.get_room_shortstatehash(room_id)? else {
error!("Room {} has no state", room_id); error!("Room {} has no state", room_id);
continue; continue;
}; };
let since_shortstatehash = services() let since_shortstatehash = services
.rooms .rooms
.user .user
.get_token_shortstatehash(room_id, globalsince)?; .get_token_shortstatehash(room_id, globalsince)?;
let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash let since_sender_member: Option<RoomMemberEventContent> = since_shortstatehash
.and_then(|shortstatehash| { .and_then(|shortstatehash| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str())
@ -1148,7 +1157,7 @@ pub(crate) async fn sync_events_v4_route(
.ok() .ok()
}); });
let encrypted_room = services() let encrypted_room = services
.rooms .rooms
.state_accessor .state_accessor
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")?
@ -1160,7 +1169,7 @@ pub(crate) async fn sync_events_v4_route(
continue; continue;
} }
let since_encryption = services().rooms.state_accessor.state_get( let since_encryption = services.rooms.state_accessor.state_get(
since_shortstatehash, since_shortstatehash,
&StateEventType::RoomEncryption, &StateEventType::RoomEncryption,
"", "",
@ -1171,12 +1180,12 @@ pub(crate) async fn sync_events_v4_route(
let new_encrypted_room = encrypted_room && since_encryption.is_none(); let new_encrypted_room = encrypted_room && since_encryption.is_none();
if encrypted_room { if encrypted_room {
let current_state_ids = services() let current_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(current_shortstatehash) .state_full_ids(current_shortstatehash)
.await?; .await?;
let since_state_ids = services() let since_state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(since_shortstatehash) .state_full_ids(since_shortstatehash)
@ -1184,7 +1193,7 @@ pub(crate) async fn sync_events_v4_route(
for (key, id) in current_state_ids { for (key, id) in current_state_ids {
if since_state_ids.get(&key) != Some(&id) { if since_state_ids.get(&key) != Some(&id) {
let Some(pdu) = services().rooms.timeline.get_pdu(&id)? else { let Some(pdu) = services.rooms.timeline.get_pdu(&id)? else {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
continue; continue;
}; };
@ -1205,7 +1214,7 @@ pub(crate) async fn sync_events_v4_route(
match new_membership { match new_membership {
MembershipState::Join => { MembershipState::Join => {
// A new user joined an encrypted room // A new user joined an encrypted room
if !share_encrypted_room(&sender_user, &user_id, room_id)? { if !share_encrypted_room(services, &sender_user, &user_id, room_id)? {
device_list_changes.insert(user_id); device_list_changes.insert(user_id);
} }
}, },
@ -1222,7 +1231,7 @@ pub(crate) async fn sync_events_v4_route(
if joined_since_last_sync || new_encrypted_room { if joined_since_last_sync || new_encrypted_room {
// If the user is in a new encrypted room, give them all joined users // If the user is in a new encrypted room, give them all joined users
device_list_changes.extend( device_list_changes.extend(
services() services
.rooms .rooms
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
@ -1234,7 +1243,7 @@ pub(crate) async fn sync_events_v4_route(
.filter(|user_id| { .filter(|user_id| {
// Only send keys if the sender doesn't share an encrypted room with the target // Only send keys if the sender doesn't share an encrypted room with the target
// already // already
!share_encrypted_room(&sender_user, user_id, room_id).unwrap_or(false) !share_encrypted_room(services, &sender_user, user_id, room_id).unwrap_or(false)
}), }),
); );
} }
@ -1242,21 +1251,21 @@ pub(crate) async fn sync_events_v4_route(
} }
// Look for device list updates in this room // Look for device list updates in this room
device_list_changes.extend( device_list_changes.extend(
services() services
.users .users
.keys_changed(room_id.as_ref(), globalsince, None) .keys_changed(room_id.as_ref(), globalsince, None)
.filter_map(Result::ok), .filter_map(Result::ok),
); );
} }
for user_id in left_encrypted_users { for user_id in left_encrypted_users {
let dont_share_encrypted_room = services() let dont_share_encrypted_room = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.filter_map(Result::ok) .filter_map(Result::ok)
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
@ -1336,7 +1345,7 @@ pub(crate) async fn sync_events_v4_route(
); );
if let Some(conn_id) = &body.conn_id { if let Some(conn_id) = &body.conn_id {
services().users.update_sync_known_rooms( services.users.update_sync_known_rooms(
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
conn_id.clone(), conn_id.clone(),
@ -1349,7 +1358,7 @@ pub(crate) async fn sync_events_v4_route(
let mut known_subscription_rooms = BTreeSet::new(); let mut known_subscription_rooms = BTreeSet::new();
for (room_id, room) in &body.room_subscriptions { for (room_id, room) in &body.room_subscriptions {
if !services().rooms.metadata.exists(room_id)? { if !services.rooms.metadata.exists(room_id)? {
continue; continue;
} }
let todo_room = todo_rooms let todo_room = todo_rooms
@ -1375,7 +1384,7 @@ pub(crate) async fn sync_events_v4_route(
} }
if let Some(conn_id) = &body.conn_id { if let Some(conn_id) = &body.conn_id {
services().users.update_sync_known_rooms( services.users.update_sync_known_rooms(
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
conn_id.clone(), conn_id.clone(),
@ -1386,7 +1395,7 @@ pub(crate) async fn sync_events_v4_route(
} }
if let Some(conn_id) = &body.conn_id { if let Some(conn_id) = &body.conn_id {
services().users.update_sync_subscriptions( services.users.update_sync_subscriptions(
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
conn_id.clone(), conn_id.clone(),
@ -1398,7 +1407,7 @@ pub(crate) async fn sync_events_v4_route(
for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms {
let roomsincecount = PduCount::Normal(*roomsince); let roomsincecount = PduCount::Normal(*roomsince);
let (timeline_pdus, limited) = load_timeline(&sender_user, room_id, roomsincecount, *timeline_limit)?; let (timeline_pdus, limited) = load_timeline(services, &sender_user, room_id, roomsincecount, *timeline_limit)?;
if roomsince != &0 && timeline_pdus.is_empty() { if roomsince != &0 && timeline_pdus.is_empty() {
continue; continue;
@ -1431,7 +1440,7 @@ pub(crate) async fn sync_events_v4_route(
let required_state = required_state_request let required_state = required_state_request
.iter() .iter()
.map(|state| { .map(|state| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &state.0, &state.1) .room_state_get(room_id, &state.0, &state.1)
@ -1442,7 +1451,7 @@ pub(crate) async fn sync_events_v4_route(
.collect(); .collect();
// Heroes // Heroes
let heroes = services() let heroes = services
.rooms .rooms
.state_cache .state_cache
.room_members(room_id) .room_members(room_id)
@ -1450,7 +1459,7 @@ pub(crate) async fn sync_events_v4_route(
.filter(|member| member != &sender_user) .filter(|member| member != &sender_user)
.map(|member| { .map(|member| {
Ok::<_, Error>( Ok::<_, Error>(
services() services
.rooms .rooms
.state_accessor .state_accessor
.get_member(room_id, &member)? .get_member(room_id, &member)?
@ -1491,11 +1500,11 @@ pub(crate) async fn sync_events_v4_route(
rooms.insert( rooms.insert(
room_id.clone(), room_id.clone(),
sync_events::v4::SlidingSyncRoom { sync_events::v4::SlidingSyncRoom {
name: services().rooms.state_accessor.get_name(room_id)?.or(name), name: services.rooms.state_accessor.get_name(room_id)?.or(name),
avatar: if let Some(heroes_avatar) = heroes_avatar { avatar: if let Some(heroes_avatar) = heroes_avatar {
ruma::JsOption::Some(heroes_avatar) ruma::JsOption::Some(heroes_avatar)
} else { } else {
match services().rooms.state_accessor.get_avatar(room_id)? { match services.rooms.state_accessor.get_avatar(room_id)? {
ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url), ruma::JsOption::Some(avatar) => ruma::JsOption::from_option(avatar.url),
ruma::JsOption::Null => ruma::JsOption::Null, ruma::JsOption::Null => ruma::JsOption::Null,
ruma::JsOption::Undefined => ruma::JsOption::Undefined, ruma::JsOption::Undefined => ruma::JsOption::Undefined,
@ -1506,7 +1515,7 @@ pub(crate) async fn sync_events_v4_route(
invite_state: None, invite_state: None,
unread_notifications: UnreadNotificationsCount { unread_notifications: UnreadNotificationsCount {
highlight_count: Some( highlight_count: Some(
services() services
.rooms .rooms
.user .user
.highlight_count(&sender_user, room_id)? .highlight_count(&sender_user, room_id)?
@ -1514,7 +1523,7 @@ pub(crate) async fn sync_events_v4_route(
.expect("notification count can't go that high"), .expect("notification count can't go that high"),
), ),
notification_count: Some( notification_count: Some(
services() services
.rooms .rooms
.user .user
.notification_count(&sender_user, room_id)? .notification_count(&sender_user, room_id)?
@ -1527,7 +1536,7 @@ pub(crate) async fn sync_events_v4_route(
prev_batch, prev_batch,
limited, limited,
joined_count: Some( joined_count: Some(
services() services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(room_id)? .room_joined_count(room_id)?
@ -1536,7 +1545,7 @@ pub(crate) async fn sync_events_v4_route(
.unwrap_or_else(|_| uint!(0)), .unwrap_or_else(|_| uint!(0)),
), ),
invited_count: Some( invited_count: Some(
services() services
.rooms .rooms
.state_cache .state_cache
.room_invited_count(room_id)? .room_invited_count(room_id)?
@ -1571,7 +1580,7 @@ pub(crate) async fn sync_events_v4_route(
extensions: sync_events::v4::Extensions { extensions: sync_events::v4::Extensions {
to_device: if body.extensions.to_device.enabled.unwrap_or(false) { to_device: if body.extensions.to_device.enabled.unwrap_or(false) {
Some(sync_events::v4::ToDevice { Some(sync_events::v4::ToDevice {
events: services() events: services
.users .users
.get_to_device_events(&sender_user, &sender_device)?, .get_to_device_events(&sender_user, &sender_device)?,
next_batch: next_batch.to_string(), next_batch: next_batch.to_string(),
@ -1584,7 +1593,7 @@ pub(crate) async fn sync_events_v4_route(
changed: device_list_changes.into_iter().collect(), changed: device_list_changes.into_iter().collect(),
left: device_list_left.into_iter().collect(), left: device_list_left.into_iter().collect(),
}, },
device_one_time_keys_count: services() device_one_time_keys_count: services
.users .users
.count_one_time_keys(&sender_user, &sender_device)?, .count_one_time_keys(&sender_user, &sender_device)?,
// Fallback keys are not yet supported // Fallback keys are not yet supported
@ -1592,7 +1601,7 @@ pub(crate) async fn sync_events_v4_route(
}, },
account_data: sync_events::v4::AccountData { account_data: sync_events::v4::AccountData {
global: if body.extensions.account_data.enabled.unwrap_or(false) { global: if body.extensions.account_data.enabled.unwrap_or(false) {
services() services
.account_data .account_data
.changes_since(None, &sender_user, globalsince)? .changes_since(None, &sender_user, globalsince)?
.into_iter() .into_iter()

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use ruma::{ use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags}, api::client::tag::{create_tag, delete_tag, get_tags},
events::{ events::{
@ -8,17 +9,19 @@ use ruma::{
}, },
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
/// ///
/// Adds a tag to the room. /// Adds a tag to the room.
/// ///
/// - Inserts the tag into the tag event of the room account data. /// - Inserts the tag into the tag event of the room account data.
pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> { pub(crate) async fn update_tag_route(
State(services): State<crate::State>, body: Ruma<create_tag::v3::Request>,
) -> Result<create_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
@ -38,7 +41,7 @@ pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Res
.tags .tags
.insert(body.tag.clone().into(), body.tag_info.clone()); .insert(body.tag.clone().into(), body.tag_info.clone());
services().account_data.update( services.account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
@ -53,10 +56,12 @@ pub(crate) async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Res
/// Deletes a tag from the room. /// Deletes a tag from the room.
/// ///
/// - Removes the tag from the tag event of the room account data. /// - Removes the tag from the tag event of the room account data.
pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> { pub(crate) async fn delete_tag_route(
State(services): State<crate::State>, body: Ruma<delete_tag::v3::Request>,
) -> Result<delete_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
@ -73,7 +78,7 @@ pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Res
tags_event.content.tags.remove(&body.tag.clone().into()); tags_event.content.tags.remove(&body.tag.clone().into());
services().account_data.update( services.account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
@ -88,10 +93,12 @@ pub(crate) async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Res
/// Returns tags on the room. /// Returns tags on the room.
/// ///
/// - Gets the tag event of the room account data. /// - Gets the tag event of the room account data.
pub(crate) async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> { pub(crate) async fn get_tags_route(
State(services): State<crate::State>, body: Ruma<get_tags::v3::Request>,
) -> Result<get_tags::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event = services() let event = services
.account_data .account_data
.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; .get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;

View file

@ -1,12 +1,15 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::{error::ErrorKind, threads::get_threads}, api::client::{error::ErrorKind, threads::get_threads},
uint, uint,
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/threads` /// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
pub(crate) async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> { pub(crate) async fn get_threads_route(
State(services): State<crate::State>, body: Ruma<get_threads::v1::Request>,
) -> Result<get_threads::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
@ -24,14 +27,14 @@ pub(crate) async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> R
u64::MAX u64::MAX
}; };
let threads = services() let threads = services
.rooms .rooms
.threads .threads
.threads_until(sender_user, &body.room_id, from, &body.include)? .threads_until(sender_user, &body.room_id, from, &body.include)?
.take(limit) .take(limit)
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.user_can_see_event(sender_user, &body.room_id, &pdu.event_id) .user_can_see_event(sender_user, &body.room_id, &pdu.event_id)

View file

@ -1,5 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use ruma::{ use ruma::{
api::{ api::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
@ -8,19 +9,19 @@ use ruma::{
to_device::DeviceIdOrAllDevices, to_device::DeviceIdOrAllDevices,
}; };
use crate::{services, user_is_local, Error, Result, Ruma}; use crate::{user_is_local, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
/// ///
/// Send a to-device event to a set of client devices. /// Send a to-device event to a set of client devices.
pub(crate) async fn send_event_to_device_route( pub(crate) async fn send_event_to_device_route(
body: Ruma<send_event_to_device::v3::Request>, State(services): State<crate::State>, body: Ruma<send_event_to_device::v3::Request>,
) -> Result<send_event_to_device::v3::Response> { ) -> Result<send_event_to_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
// Check if this is a new transaction id // Check if this is a new transaction id
if services() if services
.transaction_ids .transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)?
.is_some() .is_some()
@ -35,9 +36,9 @@ pub(crate) async fn send_event_to_device_route(
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages.insert(target_user_id.clone(), map); messages.insert(target_user_id.clone(), map);
let count = services().globals.next_count()?; let count = services.globals.next_count()?;
services().sending.send_edu_server( services.sending.send_edu_server(
target_user_id.server_name(), target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
sender: sender_user.clone(), sender: sender_user.clone(),
@ -53,7 +54,7 @@ pub(crate) async fn send_event_to_device_route(
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => { DeviceIdOrAllDevices::DeviceId(target_device_id) => {
services().users.add_to_device_event( services.users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
target_device_id, target_device_id,
@ -65,8 +66,8 @@ pub(crate) async fn send_event_to_device_route(
}, },
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in services().users.all_device_ids(target_user_id) { for target_device_id in services.users.all_device_ids(target_user_id) {
services().users.add_to_device_event( services.users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id?, &target_device_id?,
@ -82,7 +83,7 @@ pub(crate) async fn send_event_to_device_route(
} }
// Save transaction id with empty data // Save transaction id with empty data
services() services
.transaction_ids .transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; .add_txnid(sender_user, sender_device, &body.txn_id, &[])?;

View file

@ -1,18 +1,19 @@
use axum::extract::State;
use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
use crate::{services, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
/// ///
/// Sets the typing state of the sender user. /// Sets the typing state of the sender user.
pub(crate) async fn create_typing_event_route( pub(crate) async fn create_typing_event_route(
body: Ruma<create_typing_event::v3::Request>, State(services): State<crate::State>, body: Ruma<create_typing_event::v3::Request>,
) -> Result<create_typing_event::v3::Response> { ) -> Result<create_typing_event::v3::Response> {
use create_typing_event::v3::Typing; use create_typing_event::v3::Typing;
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !services() if !services
.rooms .rooms
.state_cache .state_cache
.is_joined(sender_user, &body.room_id)? .is_joined(sender_user, &body.room_id)?
@ -23,20 +24,20 @@ pub(crate) async fn create_typing_event_route(
if let Typing::Yes(duration) = body.state { if let Typing::Yes(duration) = body.state {
let duration = utils::clamp( let duration = utils::clamp(
duration.as_millis().try_into().unwrap_or(u64::MAX), duration.as_millis().try_into().unwrap_or(u64::MAX),
services() services
.globals .globals
.config .config
.typing_client_timeout_min_s .typing_client_timeout_min_s
.checked_mul(1000) .checked_mul(1000)
.unwrap(), .unwrap(),
services() services
.globals .globals
.config .config
.typing_client_timeout_max_s .typing_client_timeout_max_s
.checked_mul(1000) .checked_mul(1000)
.unwrap(), .unwrap(),
); );
services() services
.rooms .rooms
.typing .typing
.typing_add( .typing_add(
@ -48,7 +49,7 @@ pub(crate) async fn create_typing_event_route(
) )
.await?; .await?;
} else { } else {
services() services
.rooms .rooms
.typing .typing
.typing_remove(sender_user, &body.room_id) .typing_remove(sender_user, &body.room_id)

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::warn; use conduit::warn;
use ruma::{ use ruma::{
@ -6,7 +7,7 @@ use ruma::{
OwnedRoomId, OwnedRoomId,
}; };
use crate::{services, Error, Result, Ruma, RumaResponse}; use crate::{Error, Result, Ruma, RumaResponse};
/// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms` /// # `GET /_matrix/client/unstable/uk.half-shot.msc2666/user/mutual_rooms`
/// ///
@ -17,7 +18,8 @@ use crate::{services, Error, Result, Ruma, RumaResponse};
/// An implementation of [MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666) /// An implementation of [MSC2666](https://github.com/matrix-org/matrix-spec-proposals/pull/2666)
#[tracing::instrument(skip_all, fields(%client), name = "mutual_rooms")] #[tracing::instrument(skip_all, fields(%client), name = "mutual_rooms")]
pub(crate) async fn get_mutual_rooms_route( pub(crate) async fn get_mutual_rooms_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<mutual_rooms::unstable::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<mutual_rooms::unstable::Request>,
) -> Result<mutual_rooms::unstable::Response> { ) -> Result<mutual_rooms::unstable::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -28,14 +30,14 @@ pub(crate) async fn get_mutual_rooms_route(
)); ));
} }
if !services().users.exists(&body.user_id)? { if !services.users.exists(&body.user_id)? {
return Ok(mutual_rooms::unstable::Response { return Ok(mutual_rooms::unstable::Response {
joined: vec![], joined: vec![],
next_batch_token: None, next_batch_token: None,
}); });
} }
let mutual_rooms: Vec<OwnedRoomId> = services() let mutual_rooms: Vec<OwnedRoomId> = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
@ -58,9 +60,10 @@ pub(crate) async fn get_mutual_rooms_route(
/// ///
/// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266) /// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266)
pub(crate) async fn get_room_summary_legacy( pub(crate) async fn get_room_summary_legacy(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_summary::msc3266::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_summary::msc3266::Request>,
) -> Result<RumaResponse<get_summary::msc3266::Response>> { ) -> Result<RumaResponse<get_summary::msc3266::Response>> {
get_room_summary(InsecureClientIp(client), body) get_room_summary(State(services), InsecureClientIp(client), body)
.await .await
.map(RumaResponse) .map(RumaResponse)
} }
@ -74,22 +77,19 @@ pub(crate) async fn get_room_summary_legacy(
/// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266) /// An implementation of [MSC3266](https://github.com/matrix-org/matrix-spec-proposals/pull/3266)
#[tracing::instrument(skip_all, fields(%client), name = "room_summary")] #[tracing::instrument(skip_all, fields(%client), name = "room_summary")]
pub(crate) async fn get_room_summary( pub(crate) async fn get_room_summary(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_summary::msc3266::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_summary::msc3266::Request>,
) -> Result<get_summary::msc3266::Response> { ) -> Result<get_summary::msc3266::Response> {
let sender_user = body.sender_user.as_ref(); let sender_user = body.sender_user.as_ref();
let room_id = services() let room_id = services.rooms.alias.resolve(&body.room_id_or_alias).await?;
.rooms
.alias
.resolve(&body.room_id_or_alias)
.await?;
if !services().rooms.metadata.exists(&room_id)? { if !services.rooms.metadata.exists(&room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"));
} }
if sender_user.is_none() if sender_user.is_none()
&& !services() && !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&room_id) .is_world_readable(&room_id)
@ -103,25 +103,25 @@ pub(crate) async fn get_room_summary(
Ok(get_summary::msc3266::Response { Ok(get_summary::msc3266::Response {
room_id: room_id.clone(), room_id: room_id.clone(),
canonical_alias: services() canonical_alias: services
.rooms .rooms
.state_accessor .state_accessor
.get_canonical_alias(&room_id) .get_canonical_alias(&room_id)
.unwrap_or(None), .unwrap_or(None),
avatar_url: services() avatar_url: services
.rooms .rooms
.state_accessor .state_accessor
.get_avatar(&room_id)? .get_avatar(&room_id)?
.into_option() .into_option()
.unwrap_or_default() .unwrap_or_default()
.url, .url,
guest_can_join: services().rooms.state_accessor.guest_can_join(&room_id)?, guest_can_join: services.rooms.state_accessor.guest_can_join(&room_id)?,
name: services() name: services
.rooms .rooms
.state_accessor .state_accessor
.get_name(&room_id) .get_name(&room_id)
.unwrap_or(None), .unwrap_or(None),
num_joined_members: services() num_joined_members: services
.rooms .rooms
.state_cache .state_cache
.room_joined_count(&room_id) .room_joined_count(&room_id)
@ -132,21 +132,21 @@ pub(crate) async fn get_room_summary(
}) })
.try_into() .try_into()
.expect("user count should not be that big"), .expect("user count should not be that big"),
topic: services() topic: services
.rooms .rooms
.state_accessor .state_accessor
.get_room_topic(&room_id) .get_room_topic(&room_id)
.unwrap_or(None), .unwrap_or(None),
world_readable: services() world_readable: services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&room_id) .is_world_readable(&room_id)
.unwrap_or(false), .unwrap_or(false),
join_rule: services().rooms.state_accessor.get_join_rule(&room_id)?.0, join_rule: services.rooms.state_accessor.get_join_rule(&room_id)?.0,
room_type: services().rooms.state_accessor.get_room_type(&room_id)?, room_type: services.rooms.state_accessor.get_room_type(&room_id)?,
room_version: Some(services().rooms.state.get_room_version(&room_id)?), room_version: Some(services.rooms.state.get_room_version(&room_id)?),
membership: if let Some(sender_user) = sender_user { membership: if let Some(sender_user) = sender_user {
services() services
.rooms .rooms
.state_accessor .state_accessor
.get_member(&room_id, sender_user)? .get_member(&room_id, sender_user)?
@ -154,7 +154,7 @@ pub(crate) async fn get_room_summary(
} else { } else {
None None
}, },
encryption: services() encryption: services
.rooms .rooms
.state_accessor .state_accessor
.get_room_encryption(&room_id) .get_room_encryption(&room_id)

View file

@ -1,6 +1,6 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::{response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use ruma::api::client::{ use ruma::api::client::{
discovery::{ discovery::{
discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo}, discover_homeserver::{self, HomeserverInfo, SlidingSyncProxyInfo},
@ -10,7 +10,7 @@ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/client/versions` /// # `GET /_matrix/client/versions`
/// ///
@ -62,9 +62,9 @@ pub(crate) async fn get_supported_versions_route(
/// ///
/// Returns the .well-known URL if it is configured, otherwise returns 404. /// Returns the .well-known URL if it is configured, otherwise returns 404.
pub(crate) async fn well_known_client( pub(crate) async fn well_known_client(
_body: Ruma<discover_homeserver::Request>, State(services): State<crate::State>, _body: Ruma<discover_homeserver::Request>,
) -> Result<discover_homeserver::Response> { ) -> Result<discover_homeserver::Response> {
let client_url = match services().globals.well_known_client() { let client_url = match services.globals.well_known_client() {
Some(url) => url.to_string(), Some(url) => url.to_string(),
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
}; };
@ -84,22 +84,24 @@ pub(crate) async fn well_known_client(
/// # `GET /.well-known/matrix/support` /// # `GET /.well-known/matrix/support`
/// ///
/// Server support contact and support page of a homeserver's domain. /// Server support contact and support page of a homeserver's domain.
pub(crate) async fn well_known_support(_body: Ruma<discover_support::Request>) -> Result<discover_support::Response> { pub(crate) async fn well_known_support(
let support_page = services() State(services): State<crate::State>, _body: Ruma<discover_support::Request>,
) -> Result<discover_support::Response> {
let support_page = services
.globals .globals
.well_known_support_page() .well_known_support_page()
.as_ref() .as_ref()
.map(ToString::to_string); .map(ToString::to_string);
let role = services().globals.well_known_support_role().clone(); let role = services.globals.well_known_support_role().clone();
// support page or role must be either defined for this to be valid // support page or role must be either defined for this to be valid
if support_page.is_none() && role.is_none() { if support_page.is_none() && role.is_none() {
return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Not found."));
} }
let email_address = services().globals.well_known_support_email().clone(); let email_address = services.globals.well_known_support_email().clone();
let matrix_id = services().globals.well_known_support_mxid().clone(); let matrix_id = services.globals.well_known_support_mxid().clone();
// if a role is specified, an email address or matrix id is required // if a role is specified, an email address or matrix id is required
if role.is_some() && (email_address.is_none() && matrix_id.is_none()) { if role.is_some() && (email_address.is_none() && matrix_id.is_none()) {
@ -134,10 +136,10 @@ pub(crate) async fn well_known_support(_body: Ruma<discover_support::Request>) -
/// ///
/// Endpoint provided by sliding sync proxy used by some clients such as Element /// Endpoint provided by sliding sync proxy used by some clients such as Element
/// Web as a non-standard health check. /// Web as a non-standard health check.
pub(crate) async fn syncv3_client_server_json() -> Result<impl IntoResponse> { pub(crate) async fn syncv3_client_server_json(State(services): State<crate::State>) -> Result<impl IntoResponse> {
let server_url = match services().globals.well_known_client() { let server_url = match services.globals.well_known_client() {
Some(url) => url.to_string(), Some(url) => url.to_string(),
None => match services().globals.well_known_server() { None => match services.globals.well_known_server() {
Some(url) => url.to_string(), Some(url) => url.to_string(),
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
}, },
@ -165,8 +167,8 @@ pub(crate) async fn conduwuit_server_version() -> Result<impl IntoResponse> {
/// conduwuit-specific API to return the amount of users registered on this /// conduwuit-specific API to return the amount of users registered on this
/// homeserver. Endpoint is disabled if federation is disabled for privacy. This /// homeserver. Endpoint is disabled if federation is disabled for privacy. This
/// only includes active users (not deactivated, no guests, etc) /// only includes active users (not deactivated, no guests, etc)
pub(crate) async fn conduwuit_local_user_count() -> Result<impl IntoResponse> { pub(crate) async fn conduwuit_local_user_count(State(services): State<crate::State>) -> Result<impl IntoResponse> {
let user_count = services().users.list_local_users()?.len(); let user_count = services.users.list_local_users()?.len();
Ok(Json(serde_json::json!({ Ok(Json(serde_json::json!({
"count": user_count "count": user_count

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::client::user_directory::search_users, api::client::user_directory::search_users,
events::{ events::{
@ -6,7 +7,7 @@ use ruma::{
}, },
}; };
use crate::{services, Result, Ruma}; use crate::{Result, Ruma};
/// # `POST /_matrix/client/r0/user_directory/search` /// # `POST /_matrix/client/r0/user_directory/search`
/// ///
@ -14,18 +15,20 @@ use crate::{services, Result, Ruma};
/// ///
/// - Hides any local users that aren't in any public rooms (i.e. those that /// - Hides any local users that aren't in any public rooms (i.e. those that
/// have the join rule set to public) and don't share a room with the sender /// have the join rule set to public) and don't share a room with the sender
pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> { pub(crate) async fn search_users_route(
State(services): State<crate::State>, body: Ruma<search_users::v3::Request>,
) -> Result<search_users::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10 let limit = usize::try_from(body.limit).unwrap_or(10); // default limit is 10
let mut users = services().users.iter().filter_map(|user_id| { let mut users = services.users.iter().filter_map(|user_id| {
// Filter out buggy users (they should not exist, but you never know...) // Filter out buggy users (they should not exist, but you never know...)
let user_id = user_id.ok()?; let user_id = user_id.ok()?;
let user = search_users::v3::User { let user = search_users::v3::User {
user_id: user_id.clone(), user_id: user_id.clone(),
display_name: services().users.displayname(&user_id).ok()?, display_name: services.users.displayname(&user_id).ok()?,
avatar_url: services().users.avatar_url(&user_id).ok()?, avatar_url: services.users.avatar_url(&user_id).ok()?,
}; };
let user_id_matches = user let user_id_matches = user
@ -50,13 +53,13 @@ pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) ->
// It's a matching user, but is the sender allowed to see them? // It's a matching user, but is the sender allowed to see them?
let mut user_visible = false; let mut user_visible = false;
let user_is_in_public_rooms = services() let user_is_in_public_rooms = services
.rooms .rooms
.state_cache .state_cache
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.any(|room| { .any(|room| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(&room, &StateEventType::RoomJoinRules, "") .room_state_get(&room, &StateEventType::RoomJoinRules, "")
@ -71,7 +74,7 @@ pub(crate) async fn search_users_route(body: Ruma<search_users::v3::Request>) ->
if user_is_in_public_rooms { if user_is_in_public_rooms {
user_visible = true; user_visible = true;
} else { } else {
let user_is_in_shared_rooms = services() let user_is_in_shared_rooms = services
.rooms .rooms
.user .user
.get_shared_rooms(vec![sender_user.clone(), user_id]) .get_shared_rooms(vec![sender_user.clone(), user_id])

View file

@ -1,12 +1,13 @@
use std::time::{Duration, SystemTime}; use std::time::{Duration, SystemTime};
use axum::extract::State;
use base64::{engine::general_purpose, Engine as _}; use base64::{engine::general_purpose, Engine as _};
use conduit::utils; use conduit::utils;
use hmac::{Hmac, Mac}; use hmac::{Hmac, Mac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch, UserId}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch, UserId};
use sha1::Sha1; use sha1::Sha1;
use crate::{services, Result, Ruma}; use crate::{Result, Ruma};
const RANDOM_USER_ID_LENGTH: usize = 10; const RANDOM_USER_ID_LENGTH: usize = 10;
@ -16,14 +17,14 @@ type HmacSha1 = Hmac<Sha1>;
/// ///
/// TODO: Returns information about the recommended turn server. /// TODO: Returns information about the recommended turn server.
pub(crate) async fn turn_server_route( pub(crate) async fn turn_server_route(
body: Ruma<get_turn_server_info::v3::Request>, State(services): State<crate::State>, body: Ruma<get_turn_server_info::v3::Request>,
) -> Result<get_turn_server_info::v3::Response> { ) -> Result<get_turn_server_info::v3::Response> {
let turn_secret = services().globals.turn_secret().clone(); let turn_secret = services.globals.turn_secret().clone();
let (username, password) = if !turn_secret.is_empty() { let (username, password) = if !turn_secret.is_empty() {
let expiry = SecondsSinceUnixEpoch::from_system_time( let expiry = SecondsSinceUnixEpoch::from_system_time(
SystemTime::now() SystemTime::now()
.checked_add(Duration::from_secs(services().globals.turn_ttl())) .checked_add(Duration::from_secs(services.globals.turn_ttl()))
.expect("TURN TTL should not get this high"), .expect("TURN TTL should not get this high"),
) )
.expect("time is valid"); .expect("time is valid");
@ -31,7 +32,7 @@ pub(crate) async fn turn_server_route(
let user = body.sender_user.unwrap_or_else(|| { let user = body.sender_user.unwrap_or_else(|| {
UserId::parse_with_server_name( UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
&services().globals.config.server_name, &services.globals.config.server_name,
) )
.unwrap() .unwrap()
}); });
@ -46,15 +47,15 @@ pub(crate) async fn turn_server_route(
(username, password) (username, password)
} else { } else {
( (
services().globals.turn_username().clone(), services.globals.turn_username().clone(),
services().globals.turn_password().clone(), services.globals.turn_password().clone(),
) )
}; };
Ok(get_turn_server_info::v3::Response { Ok(get_turn_server_info::v3::Response {
username, username,
password, password,
uris: services().globals.turn_uris().to_vec(), uris: services.globals.turn_uris().to_vec(),
ttl: Duration::from_secs(services().globals.turn_ttl()), ttl: Duration::from_secs(services.globals.turn_ttl()),
}) })
} }

View file

@ -2,11 +2,12 @@ use std::{mem, ops::Deref};
use axum::{async_trait, body::Body, extract::FromRequest}; use axum::{async_trait, body::Body, extract::FromRequest};
use bytes::{BufMut, BytesMut}; use bytes::{BufMut, BytesMut};
use conduit::{debug, err, trace, Error, Result}; use conduit::{debug, err, trace, utils::string::EMPTY, Error, Result};
use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId}; use ruma::{api::IncomingRequest, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId};
use service::Services;
use super::{auth, auth::Auth, request, request::Request}; use super::{auth, auth::Auth, request, request::Request};
use crate::{service::appservice::RegistrationInfo, services}; use crate::service::appservice::RegistrationInfo;
/// Extractor for Ruma request structs /// Extractor for Ruma request structs
pub(crate) struct Args<T> { pub(crate) struct Args<T> {
@ -42,11 +43,12 @@ where
type Rejection = Error; type Rejection = Error;
async fn from_request(request: hyper::Request<Body>, _: &S) -> Result<Self, Self::Rejection> { async fn from_request(request: hyper::Request<Body>, _: &S) -> Result<Self, Self::Rejection> {
let mut request = request::from(request).await?; let services = service::services(); // ???
let mut request = request::from(services, request).await?;
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&request.body).ok();
let auth = auth::auth(&mut request, &json_body, &T::METADATA).await?; let auth = auth::auth(services, &mut request, &json_body, &T::METADATA).await?;
Ok(Self { Ok(Self {
body: make_body::<T>(&mut request, &mut json_body, &auth)?, body: make_body::<T>(services, &mut request, &mut json_body, &auth)?,
origin: auth.origin, origin: auth.origin,
sender_user: auth.sender_user, sender_user: auth.sender_user,
sender_device: auth.sender_device, sender_device: auth.sender_device,
@ -62,13 +64,16 @@ impl<T> Deref for Args<T> {
fn deref(&self) -> &Self::Target { &self.body } fn deref(&self) -> &Self::Target { &self.body }
} }
fn make_body<T>(request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth) -> Result<T> fn make_body<T>(
services: &Services, request: &mut Request, json_body: &mut Option<CanonicalJsonValue>, auth: &Auth,
) -> Result<T>
where where
T: IncomingRequest, T: IncomingRequest,
{ {
let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body { let body = if let Some(CanonicalJsonValue::Object(json_body)) = json_body {
let user_id = auth.sender_user.clone().unwrap_or_else(|| { let user_id = auth.sender_user.clone().unwrap_or_else(|| {
UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") let server_name = services.globals.server_name();
UserId::parse_with_server_name(EMPTY, server_name).expect("valid user_id")
}); });
let uiaa_request = json_body let uiaa_request = json_body
@ -77,9 +82,9 @@ where
.and_then(|auth| auth.get("session")) .and_then(|auth| auth.get("session"))
.and_then(|session| session.as_str()) .and_then(|session| session.as_str())
.and_then(|session| { .and_then(|session| {
services().uiaa.get_uiaa_request( services.uiaa.get_uiaa_request(
&user_id, &user_id,
&auth.sender_device.clone().unwrap_or_else(|| "".into()), &auth.sender_device.clone().unwrap_or_else(|| EMPTY.into()),
session, session,
) )
}); });

View file

@ -6,17 +6,17 @@ use axum_extra::{
typed_header::TypedHeaderRejectionReason, typed_header::TypedHeaderRejectionReason,
TypedHeader, TypedHeader,
}; };
use conduit::Err; use conduit::{warn, Err, Error, Result};
use http::uri::PathAndQuery; use http::uri::PathAndQuery;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, AuthScheme, Metadata}, api::{client::error::ErrorKind, AuthScheme, Metadata},
server_util::authorization::XMatrix, server_util::authorization::XMatrix,
CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId, UserId,
}; };
use tracing::warn; use service::Services;
use super::request::Request; use super::request::Request;
use crate::{service::appservice::RegistrationInfo, services, Error, Result}; use crate::service::appservice::RegistrationInfo;
enum Token { enum Token {
Appservice(Box<RegistrationInfo>), Appservice(Box<RegistrationInfo>),
@ -33,7 +33,7 @@ pub(super) struct Auth {
} }
pub(super) async fn auth( pub(super) async fn auth(
request: &mut Request, json_body: &Option<CanonicalJsonValue>, metadata: &Metadata, services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>, metadata: &Metadata,
) -> Result<Auth> { ) -> Result<Auth> {
let bearer: Option<TypedHeader<Authorization<Bearer>>> = request.parts.extract().await?; let bearer: Option<TypedHeader<Authorization<Bearer>>> = request.parts.extract().await?;
let token = match &bearer { let token = match &bearer {
@ -42,9 +42,9 @@ pub(super) async fn auth(
}; };
let token = if let Some(token) = token { let token = if let Some(token) = token {
if let Some(reg_info) = services().appservice.find_from_token(token).await { if let Some(reg_info) = services.appservice.find_from_token(token).await {
Token::Appservice(Box::new(reg_info)) Token::Appservice(Box::new(reg_info))
} else if let Some((user_id, device_id)) = services().users.find_from_token(token)? { } else if let Some((user_id, device_id)) = services.users.find_from_token(token)? {
Token::User((user_id, OwnedDeviceId::from(device_id))) Token::User((user_id, OwnedDeviceId::from(device_id)))
} else { } else {
Token::Invalid Token::Invalid
@ -57,7 +57,7 @@ pub(super) async fn auth(
match request.parts.uri.path() { match request.parts.uri.path() {
// TODO: can we check this better? // TODO: can we check this better?
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
if !services() if !services
.globals .globals
.config .config
.allow_public_room_directory_without_auth .allow_public_room_directory_without_auth
@ -98,7 +98,7 @@ pub(super) async fn auth(
)) ))
} }
}, },
(AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(request, info)?), (AuthScheme::AccessToken, Token::Appservice(info)) => Ok(auth_appservice(services, request, info)?),
(AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => { (AuthScheme::None | AuthScheme::AccessTokenOptional | AuthScheme::AppserviceToken, Token::Appservice(info)) => {
Ok(Auth { Ok(Auth {
origin: None, origin: None,
@ -110,7 +110,7 @@ pub(super) async fn auth(
(AuthScheme::AccessToken, Token::None) => match request.parts.uri.path() { (AuthScheme::AccessToken, Token::None) => match request.parts.uri.path() {
// TODO: can we check this better? // TODO: can we check this better?
"/_matrix/client/v3/voip/turnServer" | "/_matrix/client/r0/voip/turnServer" => { "/_matrix/client/v3/voip/turnServer" | "/_matrix/client/r0/voip/turnServer" => {
if services().globals.config.turn_allow_guests { if services.globals.config.turn_allow_guests {
Ok(Auth { Ok(Auth {
origin: None, origin: None,
sender_user: None, sender_user: None,
@ -132,7 +132,7 @@ pub(super) async fn auth(
sender_device: Some(device_id), sender_device: Some(device_id),
appservice_info: None, appservice_info: None,
}), }),
(AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(request, json_body).await?), (AuthScheme::ServerSignatures, Token::None) => Ok(auth_server(services, request, json_body).await?),
(AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => Ok(Auth { (AuthScheme::None | AuthScheme::AppserviceToken | AuthScheme::AccessTokenOptional, Token::None) => Ok(Auth {
sender_user: None, sender_user: None,
sender_device: None, sender_device: None,
@ -150,7 +150,7 @@ pub(super) async fn auth(
} }
} }
fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> { fn auth_appservice(services: &Services, request: &Request, info: Box<RegistrationInfo>) -> Result<Auth> {
let user_id = request let user_id = request
.query .query
.user_id .user_id
@ -159,7 +159,7 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut
|| { || {
UserId::parse_with_server_name( UserId::parse_with_server_name(
info.registration.sender_localpart.as_str(), info.registration.sender_localpart.as_str(),
services().globals.server_name(), services.globals.server_name(),
) )
}, },
UserId::parse, UserId::parse,
@ -170,7 +170,7 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut
return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace.")); return Err(Error::BadRequest(ErrorKind::Exclusive, "User is not in namespace."));
} }
if !services().users.exists(&user_id)? { if !services.users.exists(&user_id)? {
return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "User does not exist."));
} }
@ -182,8 +182,10 @@ fn auth_appservice(request: &Request, info: Box<RegistrationInfo>) -> Result<Aut
}) })
} }
async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValue>) -> Result<Auth> { async fn auth_server(
if !services().globals.allow_federation() { services: &Services, request: &mut Request, json_body: &Option<CanonicalJsonValue>,
) -> Result<Auth> {
if !services.globals.allow_federation() {
return Err!(Config("allow_federation", "Federation is disabled.")); return Err!(Config("allow_federation", "Federation is disabled."));
} }
@ -216,7 +218,7 @@ async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValu
), ),
)]); )]);
let server_destination = services().globals.server_name().as_str().to_owned(); let server_destination = services.globals.server_name().as_str().to_owned();
if let Some(destination) = x_matrix.destination.as_ref() { if let Some(destination) = x_matrix.destination.as_ref() {
if destination != &server_destination { if destination != &server_destination {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Invalid authorization."));
@ -247,7 +249,7 @@ async fn auth_server(request: &mut Request, json_body: &Option<CanonicalJsonValu
request_map.insert("content".to_owned(), json_body.clone()); request_map.insert("content".to_owned(), json_body.clone());
}; };
let keys_result = services() let keys_result = services
.rooms .rooms
.event_handler .event_handler
.fetch_signing_keys_for_server(origin, vec![x_matrix.key.to_string()]) .fetch_signing_keys_for_server(origin, vec![x_matrix.key.to_string()])

View file

@ -2,11 +2,10 @@ use std::str;
use axum::{extract::Path, RequestExt, RequestPartsExt}; use axum::{extract::Path, RequestExt, RequestPartsExt};
use bytes::Bytes; use bytes::Bytes;
use conduit::err; use conduit::{err, Result};
use http::request::Parts; use http::request::Parts;
use serde::Deserialize; use serde::Deserialize;
use service::Services;
use crate::{services, Result};
#[derive(Deserialize)] #[derive(Deserialize)]
pub(super) struct QueryParams { pub(super) struct QueryParams {
@ -21,7 +20,7 @@ pub(super) struct Request {
pub(super) parts: Parts, pub(super) parts: Parts,
} }
pub(super) async fn from(request: hyper::Request<axum::body::Body>) -> Result<Request> { pub(super) async fn from(services: &Services, request: hyper::Request<axum::body::Body>) -> Result<Request> {
let limited = request.with_limited_body(); let limited = request.with_limited_body();
let (mut parts, body) = limited.into_parts(); let (mut parts, body) = limited.into_parts();
@ -30,7 +29,7 @@ pub(super) async fn from(request: hyper::Request<axum::body::Body>) -> Result<Re
let query = let query =
serde_html_form::from_str(query).map_err(|e| err!(Request(Unknown("Failed to read query parameters: {e}"))))?; serde_html_form::from_str(query).map_err(|e| err!(Request(Unknown("Failed to read query parameters: {e}"))))?;
let max_body_size = services().globals.config.max_request_size; let max_body_size = services.globals.config.max_request_size;
let body = axum::body::to_bytes(body, max_body_size) let body = axum::body::to_bytes(body, max_body_size)
.await .await

View file

@ -1,9 +1,10 @@
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::backfill::get_backfill}, api::{client::error::ErrorKind, federation::backfill::get_backfill},
uint, user_id, MilliSecondsSinceUnixEpoch, uint, user_id, MilliSecondsSinceUnixEpoch,
}; };
use service::{sending::convert_to_outgoing_federation_event, services}; use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -11,19 +12,21 @@ use crate::Ruma;
/// ///
/// Retrieves events from before the sender joined the room, if the room's /// Retrieves events from before the sender joined the room, if the room's
/// history visibility allows. /// history visibility allows.
pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) -> Result<get_backfill::v1::Response> { pub(crate) async fn get_backfill_route(
State(services): State<crate::State>, body: Ruma<get_backfill::v1::Request>,
) -> Result<get_backfill::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)?
&& !services() && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)?
@ -34,7 +37,7 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) ->
let until = body let until = body
.v .v
.iter() .iter()
.map(|event_id| services().rooms.timeline.get_pdu_count(event_id)) .map(|event_id| services.rooms.timeline.get_pdu_count(event_id))
.filter_map(|r| r.ok().flatten()) .filter_map(|r| r.ok().flatten())
.max() .max()
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Event not found."))?;
@ -45,7 +48,7 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) ->
.try_into() .try_into()
.expect("UInt could not be converted to usize"); .expect("UInt could not be converted to usize");
let all_events = services() let all_events = services
.rooms .rooms
.timeline .timeline
.pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)?
@ -55,20 +58,20 @@ pub(crate) async fn get_backfill_route(body: Ruma<get_backfill::v1::Request>) ->
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|(_, e)| { .filter(|(_, e)| {
matches!( matches!(
services() services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, &e.room_id, &e.event_id,), .server_can_see_event(origin, &e.room_id, &e.event_id,),
Ok(true), Ok(true),
) )
}) })
.map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) .map(|(_, pdu)| services.rooms.timeline.get_pdu_json(&pdu.event_id))
.filter_map(|r| r.ok().flatten()) .filter_map(|r| r.ok().flatten())
.map(convert_to_outgoing_federation_event) .map(convert_to_outgoing_federation_event)
.collect(); .collect();
Ok(get_backfill::v1::Response { Ok(get_backfill::v1::Response {
origin: services().globals.server_name().to_owned(), origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdus: events, pdus: events,
}) })

View file

@ -1,9 +1,10 @@
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::event::get_event}, api::{client::error::ErrorKind, federation::event::get_event},
MilliSecondsSinceUnixEpoch, RoomId, MilliSecondsSinceUnixEpoch, RoomId,
}; };
use service::{sending::convert_to_outgoing_federation_event, services}; use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -13,10 +14,12 @@ use crate::Ruma;
/// ///
/// - Only works if a user of this server is currently invited or joined the /// - Only works if a user of this server is currently invited or joined the
/// room /// room
pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Result<get_event::v1::Response> { pub(crate) async fn get_event_route(
State(services): State<crate::State>, body: Ruma<get_event::v1::Request>,
) -> Result<get_event::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
let event = services() let event = services
.rooms .rooms
.timeline .timeline
.get_pdu_json(&body.event_id)? .get_pdu_json(&body.event_id)?
@ -30,16 +33,13 @@ pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Resul
let room_id = let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
if !services().rooms.state_accessor.is_world_readable(room_id)? if !services.rooms.state_accessor.is_world_readable(room_id)?
&& !services() && !services.rooms.state_cache.server_in_room(origin, room_id)?
.rooms
.state_cache
.server_in_room(origin, room_id)?
{ {
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
} }
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, room_id, &body.event_id)? .server_can_see_event(origin, room_id, &body.event_id)?
@ -48,7 +48,7 @@ pub(crate) async fn get_event_route(body: Ruma<get_event::v1::Request>) -> Resul
} }
Ok(get_event::v1::Response { Ok(get_event::v1::Response {
origin: services().globals.server_name().to_owned(), origin: services.globals.server_name().to_owned(),
origin_server_ts: MilliSecondsSinceUnixEpoch::now(), origin_server_ts: MilliSecondsSinceUnixEpoch::now(),
pdu: convert_to_outgoing_federation_event(event), pdu: convert_to_outgoing_federation_event(event),
}) })

View file

@ -1,11 +1,12 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::authorization::get_event_authorization}, api::{client::error::ErrorKind, federation::authorization::get_event_authorization},
RoomId, RoomId,
}; };
use service::{sending::convert_to_outgoing_federation_event, services}; use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -15,20 +16,20 @@ use crate::Ruma;
/// ///
/// - This does not include the event itself /// - This does not include the event itself
pub(crate) async fn get_event_authorization_route( pub(crate) async fn get_event_authorization_route(
body: Ruma<get_event_authorization::v1::Request>, State(services): State<crate::State>, body: Ruma<get_event_authorization::v1::Request>,
) -> Result<get_event_authorization::v1::Response> { ) -> Result<get_event_authorization::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)?
&& !services() && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)?
@ -36,7 +37,7 @@ pub(crate) async fn get_event_authorization_route(
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
} }
let event = services() let event = services
.rooms .rooms
.timeline .timeline
.get_pdu_json(&body.event_id)? .get_pdu_json(&body.event_id)?
@ -50,7 +51,7 @@ pub(crate) async fn get_event_authorization_route(
let room_id = let room_id =
<&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?; <&RoomId>::try_from(room_id_str).map_err(|_| Error::bad_database("Invalid room_id in event in database."))?;
let auth_chain_ids = services() let auth_chain_ids = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(room_id, vec![Arc::from(&*body.event_id)])
@ -58,7 +59,7 @@ pub(crate) async fn get_event_authorization_route(
Ok(get_event_authorization::v1::Response { Ok(get_event_authorization::v1::Response {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
.filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok()?)
.map(convert_to_outgoing_federation_event) .map(convert_to_outgoing_federation_event)
.collect(), .collect(),
}) })

View file

@ -1,9 +1,10 @@
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::event::get_missing_events}, api::{client::error::ErrorKind, federation::event::get_missing_events},
OwnedEventId, RoomId, OwnedEventId, RoomId,
}; };
use service::{sending::convert_to_outgoing_federation_event, services}; use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -11,20 +12,20 @@ use crate::Ruma;
/// ///
/// Retrieves events that the sender is missing. /// Retrieves events that the sender is missing.
pub(crate) async fn get_missing_events_route( pub(crate) async fn get_missing_events_route(
body: Ruma<get_missing_events::v1::Request>, State(services): State<crate::State>, body: Ruma<get_missing_events::v1::Request>,
) -> Result<get_missing_events::v1::Response> { ) -> Result<get_missing_events::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)?
&& !services() && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)?
@ -43,7 +44,7 @@ pub(crate) async fn get_missing_events_route(
let mut i: usize = 0; let mut i: usize = 0;
while i < queued_events.len() && events.len() < limit { while i < queued_events.len() && events.len() < limit {
if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { if let Some(pdu) = services.rooms.timeline.get_pdu_json(&queued_events[i])? {
let room_id_str = pdu let room_id_str = pdu
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
@ -61,7 +62,7 @@ pub(crate) async fn get_missing_events_route(
continue; continue;
} }
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.server_can_see_event(origin, &body.room_id, &queued_events[i])? .server_can_see_event(origin, &body.room_id, &queued_events[i])?

View file

@ -1,16 +1,19 @@
use axum::extract::State;
use ruma::api::{client::error::ErrorKind, federation::space::get_hierarchy}; use ruma::api::{client::error::ErrorKind, federation::space::get_hierarchy};
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/federation/v1/hierarchy/{roomId}` /// # `GET /_matrix/federation/v1/hierarchy/{roomId}`
/// ///
/// Gets the space tree in a depth-first manner to locate child rooms of a given /// Gets the space tree in a depth-first manner to locate child rooms of a given
/// space. /// space.
pub(crate) async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> { pub(crate) async fn get_hierarchy_route(
State(services): State<crate::State>, body: Ruma<get_hierarchy::v1::Request>,
) -> Result<get_hierarchy::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
if services().rooms.metadata.exists(&body.room_id)? { if services.rooms.metadata.exists(&body.room_id)? {
services() services
.rooms .rooms
.spaces .spaces
.get_federation_hierarchy(&body.room_id, origin, body.suggested_only) .get_federation_hierarchy(&body.room_id, origin, body.suggested_only)

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use conduit::{utils, warn, Error, PduEvent, Result}; use conduit::{utils, warn, Error, PduEvent, Result};
use ruma::{ use ruma::{
@ -6,7 +7,7 @@ use ruma::{
serde::JsonObject, serde::JsonObject,
CanonicalJsonValue, EventId, OwnedUserId, CanonicalJsonValue, EventId, OwnedUserId,
}; };
use service::{sending::convert_to_outgoing_federation_event, server_is_ours, services}; use service::{sending::convert_to_outgoing_federation_event, server_is_ours};
use crate::Ruma; use crate::Ruma;
@ -15,17 +16,18 @@ use crate::Ruma;
/// Invites a remote user to a room. /// Invites a remote user to a room.
#[tracing::instrument(skip_all, fields(%client), name = "invite")] #[tracing::instrument(skip_all, fields(%client), name = "invite")]
pub(crate) async fn create_invite_route( pub(crate) async fn create_invite_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<create_invite::v2::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<create_invite::v2::Request>,
) -> Result<create_invite::v2::Response> { ) -> Result<create_invite::v2::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
// ACL check origin // ACL check origin
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.globals .globals
.supported_room_versions() .supported_room_versions()
.contains(&body.room_version) .contains(&body.room_version)
@ -39,7 +41,7 @@ pub(crate) async fn create_invite_route(
} }
if let Some(server) = body.room_id.server_name() { if let Some(server) = body.room_id.server_name() {
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -52,7 +54,7 @@ pub(crate) async fn create_invite_route(
} }
} }
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -94,14 +96,14 @@ pub(crate) async fn create_invite_route(
} }
// Make sure we're not ACL'ed from their room. // Make sure we're not ACL'ed from their room.
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(invited_user.server_name(), &body.room_id)?; .acl_check(invited_user.server_name(), &body.room_id)?;
ruma::signatures::hash_and_sign_event( ruma::signatures::hash_and_sign_event(
services().globals.server_name().as_str(), services.globals.server_name().as_str(),
services().globals.keypair(), services.globals.keypair(),
&mut signed_event, &mut signed_event,
&body.room_version, &body.room_version,
) )
@ -127,14 +129,14 @@ pub(crate) async fn create_invite_route(
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user ID."))?;
if services().rooms.metadata.is_banned(&body.room_id)? && !services().users.is_admin(&invited_user)? { if services.rooms.metadata.is_banned(&body.room_id)? && !services.users.is_admin(&invited_user)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"This room is banned on this homeserver.", "This room is banned on this homeserver.",
)); ));
} }
if services().globals.block_non_admin_invites() && !services().users.is_admin(&invited_user)? { if services.globals.block_non_admin_invites() && !services.users.is_admin(&invited_user)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"This server does not allow room invites.", "This server does not allow room invites.",
@ -155,12 +157,12 @@ pub(crate) async fn create_invite_route(
// If we are active in the room, the remote server will notify us about the join // If we are active in the room, the remote server will notify us about the join
// via /send // via /send
if !services() if !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(services().globals.server_name(), &body.room_id)? .server_in_room(services.globals.server_name(), &body.room_id)?
{ {
services().rooms.state_cache.update_membership( services.rooms.state_cache.update_membership(
&body.room_id, &body.room_id,
&invited_user, &invited_user,
RoomMemberEventContent::new(MembershipState::Invite), RoomMemberEventContent::new(MembershipState::Invite),

View file

@ -3,7 +3,7 @@ use std::{
time::{Duration, SystemTime}, time::{Duration, SystemTime},
}; };
use axum::{response::IntoResponse, Json}; use axum::{extract::State, response::IntoResponse, Json};
use ruma::{ use ruma::{
api::{ api::{
federation::discovery::{get_server_keys, ServerSigningKeys, VerifyKey}, federation::discovery::{get_server_keys, ServerSigningKeys, VerifyKey},
@ -13,7 +13,7 @@ use ruma::{
MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId,
}; };
use crate::{services, Result}; use crate::Result;
/// # `GET /_matrix/key/v2/server` /// # `GET /_matrix/key/v2/server`
/// ///
@ -23,20 +23,20 @@ use crate::{services, Result};
/// this will be valid forever. /// this will be valid forever.
// Response type for this endpoint is Json because we need to calculate a // Response type for this endpoint is Json because we need to calculate a
// signature for the response // signature for the response
pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> { pub(crate) async fn get_server_keys_route(State(services): State<crate::State>) -> Result<impl IntoResponse> {
let verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> = BTreeMap::from([( let verify_keys: BTreeMap<OwnedServerSigningKeyId, VerifyKey> = BTreeMap::from([(
format!("ed25519:{}", services().globals.keypair().version()) format!("ed25519:{}", services.globals.keypair().version())
.try_into() .try_into()
.expect("found invalid server signing keys in DB"), .expect("found invalid server signing keys in DB"),
VerifyKey { VerifyKey {
key: Base64::new(services().globals.keypair().public_key().to_vec()), key: Base64::new(services.globals.keypair().public_key().to_vec()),
}, },
)]); )]);
let mut response = serde_json::from_slice( let mut response = serde_json::from_slice(
get_server_keys::v2::Response { get_server_keys::v2::Response {
server_key: Raw::new(&ServerSigningKeys { server_key: Raw::new(&ServerSigningKeys {
server_name: services().globals.server_name().to_owned(), server_name: services.globals.server_name().to_owned(),
verify_keys, verify_keys,
old_verify_keys: BTreeMap::new(), old_verify_keys: BTreeMap::new(),
signatures: BTreeMap::new(), signatures: BTreeMap::new(),
@ -56,8 +56,8 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> {
.unwrap(); .unwrap();
ruma::signatures::sign_json( ruma::signatures::sign_json(
services().globals.server_name().as_str(), services.globals.server_name().as_str(),
services().globals.keypair(), services.globals.keypair(),
&mut response, &mut response,
) )
.unwrap(); .unwrap();
@ -71,4 +71,6 @@ pub(crate) async fn get_server_keys_route() -> Result<impl IntoResponse> {
/// ///
/// - Matrix does not support invalidating public keys, so the key returned by /// - Matrix does not support invalidating public keys, so the key returned by
/// this will be valid forever. /// this will be valid forever.
pub(crate) async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_server_keys_route().await } pub(crate) async fn get_server_keys_deprecated_route(State(services): State<crate::State>) -> impl IntoResponse {
get_server_keys_route(State(services)).await
}

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::prepare_join_event}, api::{client::error::ErrorKind, federation::membership::prepare_join_event},
events::{ events::{
@ -12,15 +13,18 @@ use ruma::{
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tracing::warn; use tracing::warn;
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; use crate::{
service::{pdu::PduBuilder, Services},
Error, Result, Ruma,
};
/// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` /// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}`
/// ///
/// Creates a join template. /// Creates a join template.
pub(crate) async fn create_join_event_template_route( pub(crate) async fn create_join_event_template_route(
body: Ruma<prepare_join_event::v1::Request>, State(services): State<crate::State>, body: Ruma<prepare_join_event::v1::Request>,
) -> Result<prepare_join_event::v1::Response> { ) -> Result<prepare_join_event::v1::Response> {
if !services().rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
@ -33,12 +37,12 @@ pub(crate) async fn create_join_event_template_route(
} }
// ACL check origin server // ACL check origin server
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -56,7 +60,7 @@ pub(crate) async fn create_join_event_template_route(
} }
if let Some(server) = body.room_id.server_name() { if let Some(server) = body.room_id.server_name() {
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -69,25 +73,25 @@ pub(crate) async fn create_join_event_template_route(
} }
} }
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let join_authorized_via_users_server = if (services() let join_authorized_via_users_server = if (services
.rooms .rooms
.state_cache .state_cache
.is_left(&body.user_id, &body.room_id) .is_left(&body.user_id, &body.room_id)
.unwrap_or(true)) .unwrap_or(true))
&& user_can_perform_restricted_join(&body.user_id, &body.room_id, &room_version_id)? && user_can_perform_restricted_join(services, &body.user_id, &body.room_id, &room_version_id)?
{ {
let auth_user = services() let auth_user = services
.rooms .rooms
.state_cache .state_cache
.room_members(&body.room_id) .room_members(&body.room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|user| user.server_name() == services().globals.server_name()) .filter(|user| user.server_name() == services.globals.server_name())
.find(|user| { .find(|user| {
services() services
.rooms .rooms
.state_accessor .state_accessor
.user_can_invite(&body.room_id, user, &body.user_id, &state_lock) .user_can_invite(&body.room_id, user, &body.user_id, &state_lock)
@ -106,7 +110,7 @@ pub(crate) async fn create_join_event_template_route(
None None
}; };
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
if !body.ver.contains(&room_version_id) { if !body.ver.contains(&room_version_id) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::IncompatibleRoomVersion { ErrorKind::IncompatibleRoomVersion {
@ -128,7 +132,7 @@ pub(crate) async fn create_join_event_template_route(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content, content,
@ -157,12 +161,11 @@ pub(crate) async fn create_join_event_template_route(
/// externally, either by using the state cache or attempting to authorize the /// externally, either by using the state cache or attempting to authorize the
/// event. /// event.
pub(crate) fn user_can_perform_restricted_join( pub(crate) fn user_can_perform_restricted_join(
user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId, services: &Services, user_id: &UserId, room_id: &RoomId, room_version_id: &RoomVersionId,
) -> Result<bool> { ) -> Result<bool> {
use RoomVersionId::*; use RoomVersionId::*;
let join_rules_event = let join_rules_event = services
services()
.rooms .rooms
.state_accessor .state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; .room_state_get(room_id, &StateEventType::RoomJoinRules, "")?;
@ -198,7 +201,7 @@ pub(crate) fn user_can_perform_restricted_join(
} }
}) })
.any(|m| { .any(|m| {
services() services
.rooms .rooms
.state_cache .state_cache
.is_joined(user_id, &m.room_id) .is_joined(user_id, &m.room_id)

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::prepare_leave_event}, api::{client::error::ErrorKind, federation::membership::prepare_leave_event},
@ -9,15 +10,15 @@ use ruma::{
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use super::make_join::maybe_strip_event_id; use super::make_join::maybe_strip_event_id;
use crate::{service::pdu::PduBuilder, services, Ruma}; use crate::{service::pdu::PduBuilder, Ruma};
/// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}` /// # `PUT /_matrix/federation/v1/make_leave/{roomId}/{eventId}`
/// ///
/// Creates a leave template. /// Creates a leave template.
pub(crate) async fn create_leave_event_template_route( pub(crate) async fn create_leave_event_template_route(
body: Ruma<prepare_leave_event::v1::Request>, State(services): State<crate::State>, body: Ruma<prepare_leave_event::v1::Request>,
) -> Result<prepare_leave_event::v1::Response> { ) -> Result<prepare_leave_event::v1::Response> {
if !services().rooms.metadata.exists(&body.room_id)? { if !services.rooms.metadata.exists(&body.room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
@ -30,13 +31,13 @@ pub(crate) async fn create_leave_event_template_route(
} }
// ACL check origin // ACL check origin
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; let room_version_id = services.rooms.state.get_room_version(&body.room_id)?;
let state_lock = services().rooms.state.mutex.lock(&body.room_id).await; let state_lock = services.rooms.state.mutex.lock(&body.room_id).await;
let content = to_raw_value(&RoomMemberEventContent { let content = to_raw_value(&RoomMemberEventContent {
avatar_url: None, avatar_url: None,
blurhash: None, blurhash: None,
@ -49,7 +50,7 @@ pub(crate) async fn create_leave_event_template_route(
}) })
.expect("member event is valid value"); .expect("member event is valid value");
let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( let (_pdu, mut pdu_json) = services.rooms.timeline.create_hash_and_sign_event(
PduBuilder { PduBuilder {
event_type: TimelineEventType::RoomMember, event_type: TimelineEventType::RoomMember,
content, content,

View file

@ -1,16 +1,15 @@
use axum::extract::State;
use ruma::api::federation::openid::get_openid_userinfo; use ruma::api::federation::openid::get_openid_userinfo;
use crate::{services, Result, Ruma}; use crate::{Result, Ruma};
/// # `GET /_matrix/federation/v1/openid/userinfo` /// # `GET /_matrix/federation/v1/openid/userinfo`
/// ///
/// Get information about the user that generated the OpenID token. /// Get information about the user that generated the OpenID token.
pub(crate) async fn get_openid_userinfo_route( pub(crate) async fn get_openid_userinfo_route(
body: Ruma<get_openid_userinfo::v1::Request>, State(services): State<crate::State>, body: Ruma<get_openid_userinfo::v1::Request>,
) -> Result<get_openid_userinfo::v1::Response> { ) -> Result<get_openid_userinfo::v1::Response> {
Ok(get_openid_userinfo::v1::Response::new( Ok(get_openid_userinfo::v1::Response::new(
services() services.users.find_from_openid_token(&body.access_token)?,
.users
.find_from_openid_token(&body.access_token)?,
)) ))
} }

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use axum_client_ip::InsecureClientIp; use axum_client_ip::InsecureClientIp;
use ruma::{ use ruma::{
api::{ api::{
@ -7,16 +8,17 @@ use ruma::{
directory::Filter, directory::Filter,
}; };
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `POST /_matrix/federation/v1/publicRooms` /// # `POST /_matrix/federation/v1/publicRooms`
/// ///
/// Lists the public rooms on this server. /// Lists the public rooms on this server.
#[tracing::instrument(skip_all, fields(%client), name = "publicrooms")] #[tracing::instrument(skip_all, fields(%client), name = "publicrooms")]
pub(crate) async fn get_public_rooms_filtered_route( pub(crate) async fn get_public_rooms_filtered_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms_filtered::v1::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_public_rooms_filtered::v1::Request>,
) -> Result<get_public_rooms_filtered::v1::Response> { ) -> Result<get_public_rooms_filtered::v1::Response> {
if !services() if !services
.globals .globals
.allow_public_room_directory_over_federation() .allow_public_room_directory_over_federation()
{ {
@ -24,6 +26,7 @@ pub(crate) async fn get_public_rooms_filtered_route(
} }
let response = crate::client::get_public_rooms_filtered_helper( let response = crate::client::get_public_rooms_filtered_helper(
services,
None, None,
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
@ -46,9 +49,10 @@ pub(crate) async fn get_public_rooms_filtered_route(
/// Lists the public rooms on this server. /// Lists the public rooms on this server.
#[tracing::instrument(skip_all, fields(%client), "publicrooms")] #[tracing::instrument(skip_all, fields(%client), "publicrooms")]
pub(crate) async fn get_public_rooms_route( pub(crate) async fn get_public_rooms_route(
InsecureClientIp(client): InsecureClientIp, body: Ruma<get_public_rooms::v1::Request>, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<get_public_rooms::v1::Request>,
) -> Result<get_public_rooms::v1::Response> { ) -> Result<get_public_rooms::v1::Response> {
if !services() if !services
.globals .globals
.allow_public_room_directory_over_federation() .allow_public_room_directory_over_federation()
{ {
@ -56,6 +60,7 @@ pub(crate) async fn get_public_rooms_route(
} }
let response = crate::client::get_public_rooms_filtered_helper( let response = crate::client::get_public_rooms_filtered_helper(
services,
None, None,
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use get_profile_information::v1::ProfileField; use get_profile_information::v1::ProfileField;
use rand::seq::SliceRandom; use rand::seq::SliceRandom;
use ruma::{ use ruma::{
@ -8,21 +9,21 @@ use ruma::{
OwnedServerName, OwnedServerName,
}; };
use crate::{service::server_is_ours, services, Error, Result, Ruma}; use crate::{service::server_is_ours, Error, Result, Ruma};
/// # `GET /_matrix/federation/v1/query/directory` /// # `GET /_matrix/federation/v1/query/directory`
/// ///
/// Resolve a room alias to a room id. /// Resolve a room alias to a room id.
pub(crate) async fn get_room_information_route( pub(crate) async fn get_room_information_route(
body: Ruma<get_room_information::v1::Request>, State(services): State<crate::State>, body: Ruma<get_room_information::v1::Request>,
) -> Result<get_room_information::v1::Response> { ) -> Result<get_room_information::v1::Response> {
let room_id = services() let room_id = services
.rooms .rooms
.alias .alias
.resolve_local_alias(&body.room_alias)? .resolve_local_alias(&body.room_alias)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?;
let mut servers: Vec<OwnedServerName> = services() let mut servers: Vec<OwnedServerName> = services
.rooms .rooms
.state_cache .state_cache
.room_servers(&room_id) .room_servers(&room_id)
@ -37,10 +38,10 @@ pub(crate) async fn get_room_information_route(
// insert our server as the very first choice if in list // insert our server as the very first choice if in list
if let Some(server_index) = servers if let Some(server_index) = servers
.iter() .iter()
.position(|server| server == services().globals.server_name()) .position(|server| server == services.globals.server_name())
{ {
servers.swap_remove(server_index); servers.swap_remove(server_index);
servers.insert(0, services().globals.server_name().to_owned()); servers.insert(0, services.globals.server_name().to_owned());
} }
Ok(get_room_information::v1::Response { Ok(get_room_information::v1::Response {
@ -54,12 +55,9 @@ pub(crate) async fn get_room_information_route(
/// ///
/// Gets information on a profile. /// Gets information on a profile.
pub(crate) async fn get_profile_information_route( pub(crate) async fn get_profile_information_route(
body: Ruma<get_profile_information::v1::Request>, State(services): State<crate::State>, body: Ruma<get_profile_information::v1::Request>,
) -> Result<get_profile_information::v1::Response> { ) -> Result<get_profile_information::v1::Response> {
if !services() if !services.globals.allow_profile_lookup_federation_requests() {
.globals
.allow_profile_lookup_federation_requests()
{
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::forbidden(), ErrorKind::forbidden(),
"Profile lookup over federation is not allowed on this homeserver.", "Profile lookup over federation is not allowed on this homeserver.",
@ -79,18 +77,18 @@ pub(crate) async fn get_profile_information_route(
match &body.field { match &body.field {
Some(ProfileField::DisplayName) => { Some(ProfileField::DisplayName) => {
displayname = services().users.displayname(&body.user_id)?; displayname = services.users.displayname(&body.user_id)?;
}, },
Some(ProfileField::AvatarUrl) => { Some(ProfileField::AvatarUrl) => {
avatar_url = services().users.avatar_url(&body.user_id)?; avatar_url = services.users.avatar_url(&body.user_id)?;
blurhash = services().users.blurhash(&body.user_id)?; blurhash = services.users.blurhash(&body.user_id)?;
}, },
// TODO: what to do with custom // TODO: what to do with custom
Some(_) => {}, Some(_) => {},
None => { None => {
displayname = services().users.displayname(&body.user_id)?; displayname = services.users.displayname(&body.user_id)?;
avatar_url = services().users.avatar_url(&body.user_id)?; avatar_url = services.users.avatar_url(&body.user_id)?;
blurhash = services().users.blurhash(&body.user_id)?; blurhash = services.users.blurhash(&body.user_id)?;
}, },
} }

View file

@ -34,7 +34,7 @@ type ResolvedMap = BTreeMap<OwnedEventId, Result<(), Error>>;
/// Push EDUs and PDUs to this server. /// Push EDUs and PDUs to this server.
#[tracing::instrument(skip_all, fields(%client), name = "send")] #[tracing::instrument(skip_all, fields(%client), name = "send")]
pub(crate) async fn send_transaction_message_route( pub(crate) async fn send_transaction_message_route(
State(services): State<&Services>, InsecureClientIp(client): InsecureClientIp, State(services): State<crate::State>, InsecureClientIp(client): InsecureClientIp,
body: Ruma<send_transaction_message::v1::Request>, body: Ruma<send_transaction_message::v1::Request>,
) -> Result<send_transaction_message::v1::Response> { ) -> Result<send_transaction_message::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");

View file

@ -2,6 +2,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::create_join_event}, api::{client::error::ErrorKind, federation::membership::create_join_event},
@ -13,7 +14,7 @@ use ruma::{
}; };
use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use service::{ use service::{
pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, services, user_is_local, pdu::gen_event_id_canonical_json, sending::convert_to_outgoing_federation_event, user_is_local, Services,
}; };
use tokio::sync::RwLock; use tokio::sync::RwLock;
use tracing::warn; use tracing::warn;
@ -22,18 +23,18 @@ use crate::Ruma;
/// helper method for /send_join v1 and v2 /// helper method for /send_join v1 and v2
async fn create_join_event( async fn create_join_event(
origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<create_join_event::v1::RoomState> { ) -> Result<create_join_event::v1::RoomState> {
if !services().rooms.metadata.exists(room_id)? { if !services.rooms.metadata.exists(room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
// ACL check origin server // ACL check origin server
services().rooms.event_handler.acl_check(origin, room_id)?; services.rooms.event_handler.acl_check(origin, room_id)?;
// We need to return the state prior to joining, let's keep a reference to that // We need to return the state prior to joining, let's keep a reference to that
// here // here
let shortstatehash = services() let shortstatehash = services
.rooms .rooms
.state .state
.get_room_shortstatehash(room_id)? .get_room_shortstatehash(room_id)?
@ -44,7 +45,7 @@ async fn create_join_event(
// We do not add the event_id field to the pdu here because of signature and // We do not add the event_id field to the pdu here because of signature and
// hashes checks // hashes checks
let room_version_id = services().rooms.state.get_room_version(room_id)?; let room_version_id = services.rooms.state.get_room_version(room_id)?;
let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { let Ok((event_id, mut value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json // Event could not be converted to canonical json
@ -96,7 +97,7 @@ async fn create_join_event(
) )
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?; .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "sender is not a valid user ID."))?;
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(sender.server_name(), room_id)?; .acl_check(sender.server_name(), room_id)?;
@ -128,18 +129,18 @@ async fn create_join_event(
if content if content
.join_authorized_via_users_server .join_authorized_via_users_server
.is_some_and(|user| user_is_local(&user)) .is_some_and(|user| user_is_local(&user))
&& super::user_can_perform_restricted_join(&sender, room_id, &room_version_id).unwrap_or_default() && super::user_can_perform_restricted_join(services, &sender, room_id, &room_version_id).unwrap_or_default()
{ {
ruma::signatures::hash_and_sign_event( ruma::signatures::hash_and_sign_event(
services().globals.server_name().as_str(), services.globals.server_name().as_str(),
services().globals.keypair(), services.globals.keypair(),
&mut value, &mut value,
&room_version_id, &room_version_id,
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?;
} }
services() services
.rooms .rooms
.event_handler .event_handler
.fetch_required_signing_keys([&value], &pub_key_map) .fetch_required_signing_keys([&value], &pub_key_map)
@ -155,13 +156,13 @@ async fn create_join_event(
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?;
let mutex_lock = services() let mutex_lock = services
.rooms .rooms
.event_handler .event_handler
.mutex_federation .mutex_federation
.lock(room_id) .lock(room_id)
.await; .await;
let pdu_id: Vec<u8> = services() let pdu_id: Vec<u8> = services
.rooms .rooms
.event_handler .event_handler
.handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map) .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true, &pub_key_map)
@ -169,27 +170,27 @@ async fn create_join_event(
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?;
drop(mutex_lock); drop(mutex_lock);
let state_ids = services() let state_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await?; .await?;
let auth_chain_ids = services() let auth_chain_ids = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(room_id, state_ids.values().cloned().collect()) .event_ids_iter(room_id, state_ids.values().cloned().collect())
.await?; .await?;
services().sending.send_pdu_room(room_id, &pdu_id)?; services.sending.send_pdu_room(room_id, &pdu_id)?;
Ok(create_join_event::v1::RoomState { Ok(create_join_event::v1::RoomState {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
.filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) .filter_map(|id| services.rooms.timeline.get_pdu_json(&id).ok().flatten())
.map(convert_to_outgoing_federation_event) .map(convert_to_outgoing_federation_event)
.collect(), .collect(),
state: state_ids state: state_ids
.iter() .iter()
.filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) .filter_map(|(_, id)| services.rooms.timeline.get_pdu_json(id).ok().flatten())
.map(convert_to_outgoing_federation_event) .map(convert_to_outgoing_federation_event)
.collect(), .collect(),
// Event field is required if the room version supports restricted join rules. // Event field is required if the room version supports restricted join rules.
@ -204,11 +205,11 @@ async fn create_join_event(
/// ///
/// Submits a signed join event. /// Submits a signed join event.
pub(crate) async fn create_join_event_v1_route( pub(crate) async fn create_join_event_v1_route(
body: Ruma<create_join_event::v1::Request>, State(services): State<crate::State>, body: Ruma<create_join_event::v1::Request>,
) -> Result<create_join_event::v1::Response> { ) -> Result<create_join_event::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -225,7 +226,7 @@ pub(crate) async fn create_join_event_v1_route(
} }
if let Some(server) = body.room_id.server_name() { if let Some(server) = body.room_id.server_name() {
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -243,7 +244,7 @@ pub(crate) async fn create_join_event_v1_route(
} }
} }
let room_state = create_join_event(origin, &body.room_id, &body.pdu).await?; let room_state = create_join_event(services, origin, &body.room_id, &body.pdu).await?;
Ok(create_join_event::v1::Response { Ok(create_join_event::v1::Response {
room_state, room_state,
@ -254,11 +255,11 @@ pub(crate) async fn create_join_event_v1_route(
/// ///
/// Submits a signed join event. /// Submits a signed join event.
pub(crate) async fn create_join_event_v2_route( pub(crate) async fn create_join_event_v2_route(
body: Ruma<create_join_event::v2::Request>, State(services): State<crate::State>, body: Ruma<create_join_event::v2::Request>,
) -> Result<create_join_event::v2::Response> { ) -> Result<create_join_event::v2::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -271,7 +272,7 @@ pub(crate) async fn create_join_event_v2_route(
} }
if let Some(server) = body.room_id.server_name() { if let Some(server) = body.room_id.server_name() {
if services() if services
.globals .globals
.config .config
.forbidden_remote_server_names .forbidden_remote_server_names
@ -288,7 +289,7 @@ pub(crate) async fn create_join_event_v2_route(
auth_chain, auth_chain,
state, state,
event, event,
} = create_join_event(origin, &body.room_id, &body.pdu).await?; } = create_join_event(services, origin, &body.room_id, &body.pdu).await?;
let room_state = create_join_event::v2::RoomState { let room_state = create_join_event::v2::RoomState {
members_omitted: false, members_omitted: false,
auth_chain, auth_chain,

View file

@ -2,6 +2,7 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use axum::extract::State;
use ruma::{ use ruma::{
api::{client::error::ErrorKind, federation::membership::create_leave_event}, api::{client::error::ErrorKind, federation::membership::create_leave_event},
events::{ events::{
@ -14,19 +15,19 @@ use serde_json::value::RawValue as RawJsonValue;
use tokio::sync::RwLock; use tokio::sync::RwLock;
use crate::{ use crate::{
service::{pdu::gen_event_id_canonical_json, server_is_ours}, service::{pdu::gen_event_id_canonical_json, server_is_ours, Services},
services, Error, Result, Ruma, Error, Result, Ruma,
}; };
/// # `PUT /_matrix/federation/v1/send_leave/{roomId}/{eventId}` /// # `PUT /_matrix/federation/v1/send_leave/{roomId}/{eventId}`
/// ///
/// Submits a signed leave event. /// Submits a signed leave event.
pub(crate) async fn create_leave_event_v1_route( pub(crate) async fn create_leave_event_v1_route(
body: Ruma<create_leave_event::v1::Request>, State(services): State<crate::State>, body: Ruma<create_leave_event::v1::Request>,
) -> Result<create_leave_event::v1::Response> { ) -> Result<create_leave_event::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
create_leave_event(origin, &body.room_id, &body.pdu).await?; create_leave_event(services, origin, &body.room_id, &body.pdu).await?;
Ok(create_leave_event::v1::Response::new()) Ok(create_leave_event::v1::Response::new())
} }
@ -35,28 +36,30 @@ pub(crate) async fn create_leave_event_v1_route(
/// ///
/// Submits a signed leave event. /// Submits a signed leave event.
pub(crate) async fn create_leave_event_v2_route( pub(crate) async fn create_leave_event_v2_route(
body: Ruma<create_leave_event::v2::Request>, State(services): State<crate::State>, body: Ruma<create_leave_event::v2::Request>,
) -> Result<create_leave_event::v2::Response> { ) -> Result<create_leave_event::v2::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
create_leave_event(origin, &body.room_id, &body.pdu).await?; create_leave_event(services, origin, &body.room_id, &body.pdu).await?;
Ok(create_leave_event::v2::Response::new()) Ok(create_leave_event::v2::Response::new())
} }
async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue) -> Result<()> { async fn create_leave_event(
if !services().rooms.metadata.exists(room_id)? { services: &Services, origin: &ServerName, room_id: &RoomId, pdu: &RawJsonValue,
) -> Result<()> {
if !services.rooms.metadata.exists(room_id)? {
return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server."));
} }
// ACL check origin // ACL check origin
services().rooms.event_handler.acl_check(origin, room_id)?; services.rooms.event_handler.acl_check(origin, room_id)?;
let pub_key_map = RwLock::new(BTreeMap::new()); let pub_key_map = RwLock::new(BTreeMap::new());
// We do not add the event_id field to the pdu here because of signature and // We do not add the event_id field to the pdu here because of signature and
// hashes checks // hashes checks
let room_version_id = services().rooms.state.get_room_version(room_id)?; let room_version_id = services.rooms.state.get_room_version(room_id)?;
let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else { let Ok((event_id, value)) = gen_event_id_canonical_json(pdu, &room_version_id) else {
// Event could not be converted to canonical json // Event could not be converted to canonical json
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -107,7 +110,7 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson
) )
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "User ID in sender is invalid."))?;
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(sender.server_name(), room_id)?; .acl_check(sender.server_name(), room_id)?;
@ -145,19 +148,19 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "origin is not a server name."))?;
services() services
.rooms .rooms
.event_handler .event_handler
.fetch_required_signing_keys([&value], &pub_key_map) .fetch_required_signing_keys([&value], &pub_key_map)
.await?; .await?;
let mutex_lock = services() let mutex_lock = services
.rooms .rooms
.event_handler .event_handler
.mutex_federation .mutex_federation
.lock(room_id) .lock(room_id)
.await; .await;
let pdu_id: Vec<u8> = services() let pdu_id: Vec<u8> = services
.rooms .rooms
.event_handler .event_handler
.handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map) .handle_incoming_pdu(&origin, room_id, &event_id, value, true, &pub_key_map)
@ -166,14 +169,14 @@ async fn create_leave_event(origin: &ServerName, room_id: &RoomId, pdu: &RawJson
drop(mutex_lock); drop(mutex_lock);
let servers = services() let servers = services
.rooms .rooms
.state_cache .state_cache
.room_servers(room_id) .room_servers(room_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.filter(|server| !server_is_ours(server)); .filter(|server| !server_is_ours(server));
services().sending.send_pdu_servers(servers, &pdu_id)?; services.sending.send_pdu_servers(servers, &pdu_id)?;
Ok(()) Ok(())
} }

View file

@ -1,8 +1,9 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State;
use conduit::{Error, Result}; use conduit::{Error, Result};
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state};
use service::{sending::convert_to_outgoing_federation_event, services}; use service::sending::convert_to_outgoing_federation_event;
use crate::Ruma; use crate::Ruma;
@ -10,20 +11,20 @@ use crate::Ruma;
/// ///
/// Retrieves a snapshot of a room's state at a given event. /// Retrieves a snapshot of a room's state at a given event.
pub(crate) async fn get_room_state_route( pub(crate) async fn get_room_state_route(
body: Ruma<get_room_state::v1::Request>, State(services): State<crate::State>, body: Ruma<get_room_state::v1::Request>,
) -> Result<get_room_state::v1::Response> { ) -> Result<get_room_state::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)?
&& !services() && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)?
@ -31,31 +32,22 @@ pub(crate) async fn get_room_state_route(
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
} }
let shortstatehash = services() let shortstatehash = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash(&body.event_id)? .pdu_shortstatehash(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?;
let pdus = services() let pdus = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
.await? .await?
.into_values() .into_values()
.map(|id| { .map(|id| convert_to_outgoing_federation_event(services.rooms.timeline.get_pdu_json(&id).unwrap().unwrap()))
convert_to_outgoing_federation_event(
services()
.rooms
.timeline
.get_pdu_json(&id)
.unwrap()
.unwrap(),
)
})
.collect(); .collect();
let auth_chain_ids = services() let auth_chain_ids = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])
@ -64,7 +56,7 @@ pub(crate) async fn get_room_state_route(
Ok(get_room_state::v1::Response { Ok(get_room_state::v1::Response {
auth_chain: auth_chain_ids auth_chain: auth_chain_ids
.filter_map(|id| { .filter_map(|id| {
services() services
.rooms .rooms
.timeline .timeline
.get_pdu_json(&id) .get_pdu_json(&id)

View file

@ -1,28 +1,29 @@
use std::sync::Arc; use std::sync::Arc;
use axum::extract::State;
use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids}; use ruma::api::{client::error::ErrorKind, federation::event::get_room_state_ids};
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// # `GET /_matrix/federation/v1/state_ids/{roomId}`
/// ///
/// Retrieves a snapshot of a room's state at a given event, in the form of /// Retrieves a snapshot of a room's state at a given event, in the form of
/// event IDs. /// event IDs.
pub(crate) async fn get_room_state_ids_route( pub(crate) async fn get_room_state_ids_route(
body: Ruma<get_room_state_ids::v1::Request>, State(services): State<crate::State>, body: Ruma<get_room_state_ids::v1::Request>,
) -> Result<get_room_state_ids::v1::Response> { ) -> Result<get_room_state_ids::v1::Response> {
let origin = body.origin.as_ref().expect("server is authenticated"); let origin = body.origin.as_ref().expect("server is authenticated");
services() services
.rooms .rooms
.event_handler .event_handler
.acl_check(origin, &body.room_id)?; .acl_check(origin, &body.room_id)?;
if !services() if !services
.rooms .rooms
.state_accessor .state_accessor
.is_world_readable(&body.room_id)? .is_world_readable(&body.room_id)?
&& !services() && !services
.rooms .rooms
.state_cache .state_cache
.server_in_room(origin, &body.room_id)? .server_in_room(origin, &body.room_id)?
@ -30,13 +31,13 @@ pub(crate) async fn get_room_state_ids_route(
return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room.")); return Err(Error::BadRequest(ErrorKind::forbidden(), "Server is not in room."));
} }
let shortstatehash = services() let shortstatehash = services
.rooms .rooms
.state_accessor .state_accessor
.pdu_shortstatehash(&body.event_id)? .pdu_shortstatehash(&body.event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; .ok_or_else(|| Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?;
let pdu_ids = services() let pdu_ids = services
.rooms .rooms
.state_accessor .state_accessor
.state_full_ids(shortstatehash) .state_full_ids(shortstatehash)
@ -45,7 +46,7 @@ pub(crate) async fn get_room_state_ids_route(
.map(|id| (*id).to_owned()) .map(|id| (*id).to_owned())
.collect(); .collect();
let auth_chain_ids = services() let auth_chain_ids = services
.rooms .rooms
.auth_chain .auth_chain
.event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)]) .event_ids_iter(&body.room_id, vec![Arc::from(&*body.event_id)])

View file

@ -1,3 +1,4 @@
use axum::extract::State;
use ruma::api::{ use ruma::api::{
client::error::ErrorKind, client::error::ErrorKind,
federation::{ federation::{
@ -9,13 +10,15 @@ use ruma::api::{
use crate::{ use crate::{
client::{claim_keys_helper, get_keys_helper}, client::{claim_keys_helper, get_keys_helper},
service::user_is_local, service::user_is_local,
services, Error, Result, Ruma, Error, Result, Ruma,
}; };
/// # `GET /_matrix/federation/v1/user/devices/{userId}` /// # `GET /_matrix/federation/v1/user/devices/{userId}`
/// ///
/// Gets information on all devices of the user. /// Gets information on all devices of the user.
pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> Result<get_devices::v1::Response> { pub(crate) async fn get_devices_route(
State(services): State<crate::State>, body: Ruma<get_devices::v1::Request>,
) -> Result<get_devices::v1::Response> {
if !user_is_local(&body.user_id) { if !user_is_local(&body.user_id) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -27,25 +30,25 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R
Ok(get_devices::v1::Response { Ok(get_devices::v1::Response {
user_id: body.user_id.clone(), user_id: body.user_id.clone(),
stream_id: services() stream_id: services
.users .users
.get_devicelist_version(&body.user_id)? .get_devicelist_version(&body.user_id)?
.unwrap_or(0) .unwrap_or(0)
.try_into() .try_into()
.expect("version will not grow that large"), .expect("version will not grow that large"),
devices: services() devices: services
.users .users
.all_devices_metadata(&body.user_id) .all_devices_metadata(&body.user_id)
.filter_map(Result::ok) .filter_map(Result::ok)
.filter_map(|metadata| { .filter_map(|metadata| {
let device_id_string = metadata.device_id.as_str().to_owned(); let device_id_string = metadata.device_id.as_str().to_owned();
let device_display_name = if services().globals.allow_device_name_federation() { let device_display_name = if services.globals.allow_device_name_federation() {
metadata.display_name metadata.display_name
} else { } else {
Some(device_id_string) Some(device_id_string)
}; };
Some(UserDevice { Some(UserDevice {
keys: services() keys: services
.users .users
.get_device_keys(&body.user_id, &metadata.device_id) .get_device_keys(&body.user_id, &metadata.device_id)
.ok()??, .ok()??,
@ -54,10 +57,10 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R
}) })
}) })
.collect(), .collect(),
master_key: services() master_key: services
.users .users
.get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?, .get_master_key(None, &body.user_id, &|u| u.server_name() == origin)?,
self_signing_key: services() self_signing_key: services
.users .users
.get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?, .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == origin)?,
}) })
@ -66,7 +69,9 @@ pub(crate) async fn get_devices_route(body: Ruma<get_devices::v1::Request>) -> R
/// # `POST /_matrix/federation/v1/user/keys/query` /// # `POST /_matrix/federation/v1/user/keys/query`
/// ///
/// Gets devices and identity keys for the given users. /// Gets devices and identity keys for the given users.
pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<get_keys::v1::Response> { pub(crate) async fn get_keys_route(
State(services): State<crate::State>, body: Ruma<get_keys::v1::Request>,
) -> Result<get_keys::v1::Response> {
if body.device_keys.iter().any(|(u, _)| !user_is_local(u)) { if body.device_keys.iter().any(|(u, _)| !user_is_local(u)) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -75,10 +80,11 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<
} }
let result = get_keys_helper( let result = get_keys_helper(
services,
None, None,
&body.device_keys, &body.device_keys,
|u| Some(u.server_name()) == body.origin.as_deref(), |u| Some(u.server_name()) == body.origin.as_deref(),
services().globals.allow_device_name_federation(), services.globals.allow_device_name_federation(),
) )
.await?; .await?;
@ -92,7 +98,9 @@ pub(crate) async fn get_keys_route(body: Ruma<get_keys::v1::Request>) -> Result<
/// # `POST /_matrix/federation/v1/user/keys/claim` /// # `POST /_matrix/federation/v1/user/keys/claim`
/// ///
/// Claims one-time keys. /// Claims one-time keys.
pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Result<claim_keys::v1::Response> { pub(crate) async fn claim_keys_route(
State(services): State<crate::State>, body: Ruma<claim_keys::v1::Request>,
) -> Result<claim_keys::v1::Response> {
if body.one_time_keys.iter().any(|(u, _)| !user_is_local(u)) { if body.one_time_keys.iter().any(|(u, _)| !user_is_local(u)) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -100,7 +108,7 @@ pub(crate) async fn claim_keys_route(body: Ruma<claim_keys::v1::Request>) -> Res
)); ));
} }
let result = claim_keys_helper(&body.one_time_keys).await?; let result = claim_keys_helper(services, &body.one_time_keys).await?;
Ok(claim_keys::v1::Response { Ok(claim_keys::v1::Response {
one_time_keys: result.one_time_keys, one_time_keys: result.one_time_keys,

View file

@ -1,15 +1,16 @@
use axum::extract::State;
use ruma::api::{client::error::ErrorKind, federation::discovery::discover_homeserver}; use ruma::api::{client::error::ErrorKind, federation::discovery::discover_homeserver};
use crate::{services, Error, Result, Ruma}; use crate::{Error, Result, Ruma};
/// # `GET /.well-known/matrix/server` /// # `GET /.well-known/matrix/server`
/// ///
/// Returns the .well-known URL if it is configured, otherwise returns 404. /// Returns the .well-known URL if it is configured, otherwise returns 404.
pub(crate) async fn well_known_server( pub(crate) async fn well_known_server(
_body: Ruma<discover_homeserver::Request>, State(services): State<crate::State>, _body: Ruma<discover_homeserver::Request>,
) -> Result<discover_homeserver::Response> { ) -> Result<discover_homeserver::Response> {
Ok(discover_homeserver::Response { Ok(discover_homeserver::Response {
server: match services().globals.well_known_server() { server: match services.globals.well_known_server() {
Some(server_name) => server_name.to_owned(), Some(server_name) => server_name.to_owned(),
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")),
}, },

View file

@ -176,6 +176,7 @@ impl Service {
self.db.watch(user_id, device_id).await self.db.watch(user_id, device_id).await
} }
#[inline]
pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events }