fix: some compile time errors

Only 174 errors left!
This commit is contained in:
Timo Kösters 2022-09-06 23:15:09 +02:00 committed by Nyaaori
parent 82e7f57b38
commit 057f8364cc
No known key found for this signature in database
GPG key ID: E7819C3ED4D1F82E
118 changed files with 2139 additions and 2433 deletions

6
Cargo.lock generated
View file

@ -98,9 +98,9 @@ dependencies = [
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.56" version = "0.1.57"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96cf8829f67d2eab0b2dfa42c5d0ef737e0724e4a82b01b3e292456202b19716" checksum = "76464446b8bc32758d7e88ee1a804d9914cd9b1cb264c029899680b0be29826f"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -408,6 +408,7 @@ checksum = "3d7b894f5411737b7867f4827955924d7c254fc9f4d91a6aad6b097804b1018b"
name = "conduit" name = "conduit"
version = "0.3.0-next" version = "0.3.0-next"
dependencies = [ dependencies = [
"async-trait",
"axum", "axum",
"axum-server", "axum-server",
"base64 0.13.0", "base64 0.13.0",
@ -422,6 +423,7 @@ dependencies = [
"http", "http",
"image", "image",
"jsonwebtoken", "jsonwebtoken",
"lazy_static",
"lru-cache", "lru-cache",
"num_cpus", "num_cpus",
"opentelemetry", "opentelemetry",

View file

@ -90,6 +90,8 @@ figment = { version = "0.10.6", features = ["env", "toml"] }
tikv-jemalloc-ctl = { version = "0.4.2", features = ["use_std"], optional = true } tikv-jemalloc-ctl = { version = "0.4.2", features = ["use_std"], optional = true }
tikv-jemallocator = { version = "0.4.1", features = ["unprefixed_malloc_on_supported_platforms"], optional = true } tikv-jemallocator = { version = "0.4.1", features = ["unprefixed_malloc_on_supported_platforms"], optional = true }
lazy_static = "1.4.0"
async-trait = "0.1.57"
[features] [features]
default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc"] default = ["conduit_bin", "backend_sqlite", "backend_rocksdb", "jemalloc"]

View file

@ -1,12 +1,11 @@
use crate::{utils, Error, Result}; use crate::{utils, Error, Result, services};
use bytes::BytesMut; use bytes::BytesMut;
use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use ruma::api::{IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use std::{fmt::Debug, mem, time::Duration}; use std::{fmt::Debug, mem, time::Duration};
use tracing::warn; use tracing::warn;
#[tracing::instrument(skip(globals, request))] #[tracing::instrument(skip(request))]
pub(crate) async fn send_request<T: OutgoingRequest>( pub(crate) async fn send_request<T: OutgoingRequest>(
globals: &crate::database::globals::Globals,
registration: serde_yaml::Value, registration: serde_yaml::Value,
request: T, request: T,
) -> Result<T::IncomingResponse> ) -> Result<T::IncomingResponse>
@ -46,7 +45,7 @@ where
*reqwest_request.timeout_mut() = Some(Duration::from_secs(30)); *reqwest_request.timeout_mut() = Some(Duration::from_secs(30));
let url = reqwest_request.url().clone(); let url = reqwest_request.url().clone();
let mut response = globals.default_client().execute(reqwest_request).await?; let mut response = services().globals.default_client().execute(reqwest_request).await?;
// reqwest::Response -> http::Response conversion // reqwest::Response -> http::Response conversion
let status = response.status(); let status = response.status();

View file

@ -2,9 +2,7 @@ use std::sync::Arc;
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{ use crate::{
database::{admin::make_user_admin, DatabaseGuard}, utils, Error, Result, Ruma, services,
pdu::PduBuilder,
utils, Database, Error, Result, Ruma,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -42,15 +40,14 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
/// ///
/// Note: This will not reserve the username, so the username might become invalid when trying to register /// Note: This will not reserve the username, so the username might become invalid when trying to register
pub async fn get_register_available_route( pub async fn get_register_available_route(
db: DatabaseGuard,
body: Ruma<get_username_availability::v3::IncomingRequest>, body: Ruma<get_username_availability::v3::IncomingRequest>,
) -> Result<get_username_availability::v3::Response> { ) -> Result<get_username_availability::v3::Response> {
// Validate user id // Validate user id
let user_id = let user_id =
UserId::parse_with_server_name(body.username.to_lowercase(), db.globals.server_name()) UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
.ok() .ok()
.filter(|user_id| { .filter(|user_id| {
!user_id.is_historical() && user_id.server_name() == db.globals.server_name() !user_id.is_historical() && user_id.server_name() == services().globals.server_name()
}) })
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidUsername, ErrorKind::InvalidUsername,
@ -58,7 +55,7 @@ pub async fn get_register_available_route(
))?; ))?;
// Check if username is creative enough // Check if username is creative enough
if db.users.exists(&user_id)? { if services().users.exists(&user_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UserInUse, ErrorKind::UserInUse,
"Desired user ID is already taken.", "Desired user ID is already taken.",
@ -85,10 +82,9 @@ pub async fn get_register_available_route(
/// - Creates a new account and populates it with default account data /// - Creates a new account and populates it with default account data
/// - If `inhibit_login` is false: Creates a device and returns device id and access_token /// - If `inhibit_login` is false: Creates a device and returns device id and access_token
pub async fn register_route( pub async fn register_route(
db: DatabaseGuard,
body: Ruma<register::v3::IncomingRequest>, body: Ruma<register::v3::IncomingRequest>,
) -> Result<register::v3::Response> { ) -> Result<register::v3::Response> {
if !db.globals.allow_registration() && !body.from_appservice { if !services().globals.allow_registration() && !body.from_appservice {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Registration has been disabled.", "Registration has been disabled.",
@ -100,17 +96,17 @@ pub async fn register_route(
let user_id = match (&body.username, is_guest) { let user_id = match (&body.username, is_guest) {
(Some(username), false) => { (Some(username), false) => {
let proposed_user_id = let proposed_user_id =
UserId::parse_with_server_name(username.to_lowercase(), db.globals.server_name()) UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
.ok() .ok()
.filter(|user_id| { .filter(|user_id| {
!user_id.is_historical() !user_id.is_historical()
&& user_id.server_name() == db.globals.server_name() && user_id.server_name() == services().globals.server_name()
}) })
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidUsername, ErrorKind::InvalidUsername,
"Username is invalid.", "Username is invalid.",
))?; ))?;
if db.users.exists(&proposed_user_id)? { if services().users.exists(&proposed_user_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UserInUse, ErrorKind::UserInUse,
"Desired user ID is already taken.", "Desired user ID is already taken.",
@ -121,10 +117,10 @@ pub async fn register_route(
_ => loop { _ => loop {
let proposed_user_id = UserId::parse_with_server_name( let proposed_user_id = UserId::parse_with_server_name(
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
db.globals.server_name(), services().globals.server_name(),
) )
.unwrap(); .unwrap();
if !db.users.exists(&proposed_user_id)? { if !services().users.exists(&proposed_user_id)? {
break proposed_user_id; break proposed_user_id;
} }
}, },
@ -143,14 +139,12 @@ pub async fn register_route(
if !body.from_appservice { if !body.from_appservice {
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
&UserId::parse_with_server_name("", db.globals.server_name()) &UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid"), .expect("we know this is valid"),
"".into(), "".into(),
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -158,8 +152,8 @@ pub async fn register_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa.create( services().uiaa.create(
&UserId::parse_with_server_name("", db.globals.server_name()) &UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid"), .expect("we know this is valid"),
"".into(), "".into(),
&uiaainfo, &uiaainfo,
@ -178,15 +172,15 @@ pub async fn register_route(
}; };
// Create user // Create user
db.users.create(&user_id, password)?; services().users.create(&user_id, password)?;
// Default to pretty displayname // Default to pretty displayname
let displayname = format!("{} ⚡️", user_id.localpart()); let displayname = format!("{} ⚡️", user_id.localpart());
db.users services().users
.set_displayname(&user_id, Some(displayname.clone()))?; .set_displayname(&user_id, Some(displayname.clone()))?;
// Initial account data // Initial account data
db.account_data.update( services().account_data.update(
None, None,
&user_id, &user_id,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
@ -195,7 +189,6 @@ pub async fn register_route(
global: push::Ruleset::server_default(&user_id), global: push::Ruleset::server_default(&user_id),
}, },
}, },
&db.globals,
)?; )?;
// Inhibit login does not work for guests // Inhibit login does not work for guests
@ -219,7 +212,7 @@ pub async fn register_route(
let token = utils::random_string(TOKEN_LENGTH); let token = utils::random_string(TOKEN_LENGTH);
// Create device for this account // Create device for this account
db.users.create_device( services().users.create_device(
&user_id, &user_id,
&device_id, &device_id,
&token, &token,
@ -227,7 +220,7 @@ pub async fn register_route(
)?; )?;
info!("New user {} registered on this server.", user_id); info!("New user {} registered on this server.", user_id);
db.admin services().admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"New user {} registered on this server.", "New user {} registered on this server.",
user_id user_id
@ -235,14 +228,12 @@ pub async fn register_route(
// If this is the first real user, grant them admin privileges // If this is the first real user, grant them admin privileges
// Note: the server user, @conduit:servername, is generated first // Note: the server user, @conduit:servername, is generated first
if db.users.count()? == 2 { if services().users.count()? == 2 {
make_user_admin(&db, &user_id, displayname).await?; services().admin.make_user_admin(&user_id, displayname).await?;
warn!("Granting {} admin privileges as the first user", user_id); warn!("Granting {} admin privileges as the first user", user_id);
} }
db.flush()?;
Ok(register::v3::Response { Ok(register::v3::Response {
access_token: Some(token), access_token: Some(token),
user_id, user_id,
@ -265,7 +256,6 @@ pub async fn register_route(
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn change_password_route( pub async fn change_password_route(
db: DatabaseGuard,
body: Ruma<change_password::v3::IncomingRequest>, body: Ruma<change_password::v3::IncomingRequest>,
) -> Result<change_password::v3::Response> { ) -> Result<change_password::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -282,13 +272,11 @@ pub async fn change_password_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user, sender_user,
sender_device, sender_device,
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -296,32 +284,30 @@ pub async fn change_password_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa services().uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
db.users services().users
.set_password(sender_user, Some(&body.new_password))?; .set_password(sender_user, Some(&body.new_password))?;
if body.logout_devices { if body.logout_devices {
// Logout all devices except the current one // Logout all devices except the current one
for id in db for id in services()
.users .users
.all_device_ids(sender_user) .all_device_ids(sender_user)
.filter_map(|id| id.ok()) .filter_map(|id| id.ok())
.filter(|id| id != sender_device) .filter(|id| id != sender_device)
{ {
db.users.remove_device(sender_user, &id)?; services().users.remove_device(sender_user, &id)?;
} }
} }
db.flush()?;
info!("User {} changed their password.", sender_user); info!("User {} changed their password.", sender_user);
db.admin services().admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {} changed their password.", "User {} changed their password.",
sender_user sender_user
@ -336,7 +322,6 @@ pub async fn change_password_route(
/// ///
/// Note: Also works for Application Services /// Note: Also works for Application Services
pub async fn whoami_route( pub async fn whoami_route(
db: DatabaseGuard,
body: Ruma<whoami::v3::Request>, body: Ruma<whoami::v3::Request>,
) -> Result<whoami::v3::Response> { ) -> Result<whoami::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -345,7 +330,7 @@ pub async fn whoami_route(
Ok(whoami::v3::Response { Ok(whoami::v3::Response {
user_id: sender_user.clone(), user_id: sender_user.clone(),
device_id, device_id,
is_guest: db.users.is_deactivated(&sender_user)?, is_guest: services().users.is_deactivated(&sender_user)?,
}) })
} }
@ -360,7 +345,6 @@ pub async fn whoami_route(
/// - Triggers device list updates /// - Triggers device list updates
/// - Removes ability to log in again /// - Removes ability to log in again
pub async fn deactivate_route( pub async fn deactivate_route(
db: DatabaseGuard,
body: Ruma<deactivate::v3::IncomingRequest>, body: Ruma<deactivate::v3::IncomingRequest>,
) -> Result<deactivate::v3::Response> { ) -> Result<deactivate::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -377,13 +361,11 @@ pub async fn deactivate_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user, sender_user,
sender_device, sender_device,
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -391,7 +373,7 @@ pub async fn deactivate_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa services().uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
@ -399,20 +381,18 @@ pub async fn deactivate_route(
} }
// Make the user leave all rooms before deactivation // Make the user leave all rooms before deactivation
db.rooms.leave_all_rooms(&sender_user, &db).await?; services().rooms.leave_all_rooms(&sender_user).await?;
// Remove devices and mark account as deactivated // Remove devices and mark account as deactivated
db.users.deactivate_account(sender_user)?; services().users.deactivate_account(sender_user)?;
info!("User {} deactivated their account.", sender_user); info!("User {} deactivated their account.", sender_user);
db.admin services().admin
.send_message(RoomMessageEventContent::notice_plain(format!( .send_message(RoomMessageEventContent::notice_plain(format!(
"User {} deactivated their account.", "User {} deactivated their account.",
sender_user sender_user
))); )));
db.flush()?;
Ok(deactivate::v3::Response { Ok(deactivate::v3::Response {
id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport,
}) })

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use regex::Regex; use regex::Regex;
use ruma::{ use ruma::{
api::{ api::{
@ -16,24 +16,21 @@ use ruma::{
/// ///
/// Creates a new room alias on this server. /// Creates a new room alias on this server.
pub async fn create_alias_route( pub async fn create_alias_route(
db: DatabaseGuard,
body: Ruma<create_alias::v3::IncomingRequest>, body: Ruma<create_alias::v3::IncomingRequest>,
) -> Result<create_alias::v3::Response> { ) -> Result<create_alias::v3::Response> {
if body.room_alias.server_name() != db.globals.server_name() { if body.room_alias.server_name() != services().globals.server_name() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Alias is from another server.", "Alias is from another server.",
)); ));
} }
if db.rooms.id_from_alias(&body.room_alias)?.is_some() { if services().rooms.id_from_alias(&body.room_alias)?.is_some() {
return Err(Error::Conflict("Alias already exists.")); return Err(Error::Conflict("Alias already exists."));
} }
db.rooms services().rooms
.set_alias(&body.room_alias, Some(&body.room_id), &db.globals)?; .set_alias(&body.room_alias, Some(&body.room_id))?;
db.flush()?;
Ok(create_alias::v3::Response::new()) Ok(create_alias::v3::Response::new())
} }
@ -45,22 +42,19 @@ pub async fn create_alias_route(
/// - TODO: additional access control checks /// - TODO: additional access control checks
/// - TODO: Update canonical alias event /// - TODO: Update canonical alias event
pub async fn delete_alias_route( pub async fn delete_alias_route(
db: DatabaseGuard,
body: Ruma<delete_alias::v3::IncomingRequest>, body: Ruma<delete_alias::v3::IncomingRequest>,
) -> Result<delete_alias::v3::Response> { ) -> Result<delete_alias::v3::Response> {
if body.room_alias.server_name() != db.globals.server_name() { if body.room_alias.server_name() != services().globals.server_name() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Alias is from another server.", "Alias is from another server.",
)); ));
} }
db.rooms.set_alias(&body.room_alias, None, &db.globals)?; services().rooms.set_alias(&body.room_alias, None)?;
// TODO: update alt_aliases? // TODO: update alt_aliases?
db.flush()?;
Ok(delete_alias::v3::Response::new()) Ok(delete_alias::v3::Response::new())
} }
@ -70,21 +64,18 @@ pub async fn delete_alias_route(
/// ///
/// - TODO: Suggest more servers to join via /// - TODO: Suggest more servers to join via
pub async fn get_alias_route( pub async fn get_alias_route(
db: DatabaseGuard,
body: Ruma<get_alias::v3::IncomingRequest>, body: Ruma<get_alias::v3::IncomingRequest>,
) -> Result<get_alias::v3::Response> { ) -> Result<get_alias::v3::Response> {
get_alias_helper(&db, &body.room_alias).await get_alias_helper(&body.room_alias).await
} }
pub(crate) async fn get_alias_helper( pub(crate) async fn get_alias_helper(
db: &Database,
room_alias: &RoomAliasId, room_alias: &RoomAliasId,
) -> Result<get_alias::v3::Response> { ) -> Result<get_alias::v3::Response> {
if room_alias.server_name() != db.globals.server_name() { if room_alias.server_name() != services().globals.server_name() {
let response = db let response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
room_alias.server_name(), room_alias.server_name(),
federation::query::get_room_information::v1::Request { room_alias }, federation::query::get_room_information::v1::Request { room_alias },
) )
@ -97,10 +88,10 @@ pub(crate) async fn get_alias_helper(
} }
let mut room_id = None; let mut room_id = None;
match db.rooms.id_from_alias(room_alias)? { match services().rooms.id_from_alias(room_alias)? {
Some(r) => room_id = Some(r), Some(r) => room_id = Some(r),
None => { None => {
for (_id, registration) in db.appservice.all()? { for (_id, registration) in services().appservice.all()? {
let aliases = registration let aliases = registration
.get("namespaces") .get("namespaces")
.and_then(|ns| ns.get("aliases")) .and_then(|ns| ns.get("aliases"))
@ -115,17 +106,16 @@ pub(crate) async fn get_alias_helper(
if aliases if aliases
.iter() .iter()
.any(|aliases| aliases.is_match(room_alias.as_str())) .any(|aliases| aliases.is_match(room_alias.as_str()))
&& db && services()
.sending .sending
.send_appservice_request( .send_appservice_request(
&db.globals,
registration, registration,
appservice::query::query_room_alias::v1::Request { room_alias }, appservice::query::query_room_alias::v1::Request { room_alias },
) )
.await .await
.is_ok() .is_ok()
{ {
room_id = Some(db.rooms.id_from_alias(room_alias)?.ok_or_else(|| { room_id = Some(services().rooms.id_from_alias(room_alias)?.ok_or_else(|| {
Error::bad_config("Appservice lied to us. Room does not exist.") Error::bad_config("Appservice lied to us. Room does not exist.")
})?); })?);
break; break;
@ -146,6 +136,6 @@ pub(crate) async fn get_alias_helper(
Ok(get_alias::v3::Response::new( Ok(get_alias::v3::Response::new(
room_id, room_id,
vec![db.globals.server_name().to_owned()], vec![services().globals.server_name().to_owned()],
)) ))
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::api::client::{ use ruma::api::client::{
backup::{ backup::{
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
@ -14,15 +14,12 @@ use ruma::api::client::{
/// ///
/// Creates a new backup. /// Creates a new backup.
pub async fn create_backup_version_route( pub async fn create_backup_version_route(
db: DatabaseGuard,
body: Ruma<create_backup_version::v3::Request>, body: Ruma<create_backup_version::v3::Request>,
) -> Result<create_backup_version::v3::Response> { ) -> Result<create_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let version = db let version = services()
.key_backups .key_backups
.create_backup(sender_user, &body.algorithm, &db.globals)?; .create_backup(sender_user, &body.algorithm)?;
db.flush()?;
Ok(create_backup_version::v3::Response { version }) Ok(create_backup_version::v3::Response { version })
} }
@ -31,14 +28,11 @@ pub async fn create_backup_version_route(
/// ///
/// Update information about an existing backup. Only `auth_data` can be modified. /// Update information about an existing backup. Only `auth_data` can be modified.
pub async fn update_backup_version_route( pub async fn update_backup_version_route(
db: DatabaseGuard,
body: Ruma<update_backup_version::v3::IncomingRequest>, body: Ruma<update_backup_version::v3::IncomingRequest>,
) -> Result<update_backup_version::v3::Response> { ) -> Result<update_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.key_backups services().key_backups
.update_backup(sender_user, &body.version, &body.algorithm, &db.globals)?; .update_backup(sender_user, &body.version, &body.algorithm)?;
db.flush()?;
Ok(update_backup_version::v3::Response {}) Ok(update_backup_version::v3::Response {})
} }
@ -47,13 +41,12 @@ pub async fn update_backup_version_route(
/// ///
/// Get information about the latest backup version. /// Get information about the latest backup version.
pub async fn get_latest_backup_info_route( pub async fn get_latest_backup_info_route(
db: DatabaseGuard,
body: Ruma<get_latest_backup_info::v3::Request>, body: Ruma<get_latest_backup_info::v3::Request>,
) -> Result<get_latest_backup_info::v3::Response> { ) -> Result<get_latest_backup_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let (version, algorithm) = let (version, algorithm) =
db.key_backups services().key_backups
.get_latest_backup(sender_user)? .get_latest_backup(sender_user)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
@ -62,8 +55,8 @@ pub async fn get_latest_backup_info_route(
Ok(get_latest_backup_info::v3::Response { Ok(get_latest_backup_info::v3::Response {
algorithm, algorithm,
count: (db.key_backups.count_keys(sender_user, &version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &version)?, etag: services().key_backups.get_etag(sender_user, &version)?,
version, version,
}) })
} }
@ -72,11 +65,10 @@ pub async fn get_latest_backup_info_route(
/// ///
/// Get information about an existing backup. /// Get information about an existing backup.
pub async fn get_backup_info_route( pub async fn get_backup_info_route(
db: DatabaseGuard,
body: Ruma<get_backup_info::v3::IncomingRequest>, body: Ruma<get_backup_info::v3::IncomingRequest>,
) -> Result<get_backup_info::v3::Response> { ) -> Result<get_backup_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let algorithm = db let algorithm = services()
.key_backups .key_backups
.get_backup(sender_user, &body.version)? .get_backup(sender_user, &body.version)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -86,8 +78,8 @@ pub async fn get_backup_info_route(
Ok(get_backup_info::v3::Response { Ok(get_backup_info::v3::Response {
algorithm, algorithm,
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
version: body.version.to_owned(), version: body.version.to_owned(),
}) })
} }
@ -98,14 +90,11 @@ pub async fn get_backup_info_route(
/// ///
/// - Deletes both information about the backup, as well as all key data related to the backup /// - Deletes both information about the backup, as well as all key data related to the backup
pub async fn delete_backup_version_route( pub async fn delete_backup_version_route(
db: DatabaseGuard,
body: Ruma<delete_backup_version::v3::IncomingRequest>, body: Ruma<delete_backup_version::v3::IncomingRequest>,
) -> Result<delete_backup_version::v3::Response> { ) -> Result<delete_backup_version::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.key_backups.delete_backup(sender_user, &body.version)?; services().key_backups.delete_backup(sender_user, &body.version)?;
db.flush()?;
Ok(delete_backup_version::v3::Response {}) Ok(delete_backup_version::v3::Response {})
} }
@ -118,13 +107,12 @@ pub async fn delete_backup_version_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_route( pub async fn add_backup_keys_route(
db: DatabaseGuard,
body: Ruma<add_backup_keys::v3::IncomingRequest>, body: Ruma<add_backup_keys::v3::IncomingRequest>,
) -> Result<add_backup_keys::v3::Response> { ) -> Result<add_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= db != services()
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -137,22 +125,19 @@ pub async fn add_backup_keys_route(
for (room_id, room) in &body.rooms { for (room_id, room) in &body.rooms {
for (session_id, key_data) in &room.sessions { for (session_id, key_data) in &room.sessions {
db.key_backups.add_key( services().key_backups.add_key(
sender_user, sender_user,
&body.version, &body.version,
room_id, room_id,
session_id, session_id,
key_data, key_data,
&db.globals,
)? )?
} }
} }
db.flush()?;
Ok(add_backup_keys::v3::Response { Ok(add_backup_keys::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }
@ -164,13 +149,12 @@ pub async fn add_backup_keys_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_for_room_route( pub async fn add_backup_keys_for_room_route(
db: DatabaseGuard,
body: Ruma<add_backup_keys_for_room::v3::IncomingRequest>, body: Ruma<add_backup_keys_for_room::v3::IncomingRequest>,
) -> Result<add_backup_keys_for_room::v3::Response> { ) -> Result<add_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= db != services()
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -182,21 +166,18 @@ pub async fn add_backup_keys_for_room_route(
} }
for (session_id, key_data) in &body.sessions { for (session_id, key_data) in &body.sessions {
db.key_backups.add_key( services().key_backups.add_key(
sender_user, sender_user,
&body.version, &body.version,
&body.room_id, &body.room_id,
session_id, session_id,
key_data, key_data,
&db.globals,
)? )?
} }
db.flush()?;
Ok(add_backup_keys_for_room::v3::Response { Ok(add_backup_keys_for_room::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }
@ -208,13 +189,12 @@ pub async fn add_backup_keys_for_room_route(
/// - Adds the keys to the backup /// - Adds the keys to the backup
/// - Returns the new number of keys in this backup and the etag /// - Returns the new number of keys in this backup and the etag
pub async fn add_backup_keys_for_session_route( pub async fn add_backup_keys_for_session_route(
db: DatabaseGuard,
body: Ruma<add_backup_keys_for_session::v3::IncomingRequest>, body: Ruma<add_backup_keys_for_session::v3::IncomingRequest>,
) -> Result<add_backup_keys_for_session::v3::Response> { ) -> Result<add_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if Some(&body.version) if Some(&body.version)
!= db != services()
.key_backups .key_backups
.get_latest_backup_version(sender_user)? .get_latest_backup_version(sender_user)?
.as_ref() .as_ref()
@ -225,20 +205,17 @@ pub async fn add_backup_keys_for_session_route(
)); ));
} }
db.key_backups.add_key( services().key_backups.add_key(
sender_user, sender_user,
&body.version, &body.version,
&body.room_id, &body.room_id,
&body.session_id, &body.session_id,
&body.session_data, &body.session_data,
&db.globals,
)?; )?;
db.flush()?;
Ok(add_backup_keys_for_session::v3::Response { Ok(add_backup_keys_for_session::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }
@ -246,12 +223,11 @@ pub async fn add_backup_keys_for_session_route(
/// ///
/// Retrieves all keys from the backup. /// Retrieves all keys from the backup.
pub async fn get_backup_keys_route( pub async fn get_backup_keys_route(
db: DatabaseGuard,
body: Ruma<get_backup_keys::v3::IncomingRequest>, body: Ruma<get_backup_keys::v3::IncomingRequest>,
) -> Result<get_backup_keys::v3::Response> { ) -> Result<get_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let rooms = db.key_backups.get_all(sender_user, &body.version)?; let rooms = services().key_backups.get_all(sender_user, &body.version)?;
Ok(get_backup_keys::v3::Response { rooms }) Ok(get_backup_keys::v3::Response { rooms })
} }
@ -260,12 +236,11 @@ pub async fn get_backup_keys_route(
/// ///
/// Retrieves all keys from the backup for a given room. /// Retrieves all keys from the backup for a given room.
pub async fn get_backup_keys_for_room_route( pub async fn get_backup_keys_for_room_route(
db: DatabaseGuard,
body: Ruma<get_backup_keys_for_room::v3::IncomingRequest>, body: Ruma<get_backup_keys_for_room::v3::IncomingRequest>,
) -> Result<get_backup_keys_for_room::v3::Response> { ) -> Result<get_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sessions = db let sessions = services()
.key_backups .key_backups
.get_room(sender_user, &body.version, &body.room_id)?; .get_room(sender_user, &body.version, &body.room_id)?;
@ -276,12 +251,11 @@ pub async fn get_backup_keys_for_room_route(
/// ///
/// Retrieves a key from the backup. /// Retrieves a key from the backup.
pub async fn get_backup_keys_for_session_route( pub async fn get_backup_keys_for_session_route(
db: DatabaseGuard,
body: Ruma<get_backup_keys_for_session::v3::IncomingRequest>, body: Ruma<get_backup_keys_for_session::v3::IncomingRequest>,
) -> Result<get_backup_keys_for_session::v3::Response> { ) -> Result<get_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let key_data = db let key_data = services()
.key_backups .key_backups
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)? .get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -296,18 +270,15 @@ pub async fn get_backup_keys_for_session_route(
/// ///
/// Delete the keys from the backup. /// Delete the keys from the backup.
pub async fn delete_backup_keys_route( pub async fn delete_backup_keys_route(
db: DatabaseGuard,
body: Ruma<delete_backup_keys::v3::IncomingRequest>, body: Ruma<delete_backup_keys::v3::IncomingRequest>,
) -> Result<delete_backup_keys::v3::Response> { ) -> Result<delete_backup_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.key_backups.delete_all_keys(sender_user, &body.version)?; services().key_backups.delete_all_keys(sender_user, &body.version)?;
db.flush()?;
Ok(delete_backup_keys::v3::Response { Ok(delete_backup_keys::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }
@ -315,19 +286,16 @@ pub async fn delete_backup_keys_route(
/// ///
/// Delete the keys from the backup for a given room. /// Delete the keys from the backup for a given room.
pub async fn delete_backup_keys_for_room_route( pub async fn delete_backup_keys_for_room_route(
db: DatabaseGuard,
body: Ruma<delete_backup_keys_for_room::v3::IncomingRequest>, body: Ruma<delete_backup_keys_for_room::v3::IncomingRequest>,
) -> Result<delete_backup_keys_for_room::v3::Response> { ) -> Result<delete_backup_keys_for_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.key_backups services().key_backups
.delete_room_keys(sender_user, &body.version, &body.room_id)?; .delete_room_keys(sender_user, &body.version, &body.room_id)?;
db.flush()?;
Ok(delete_backup_keys_for_room::v3::Response { Ok(delete_backup_keys_for_room::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }
@ -335,18 +303,15 @@ pub async fn delete_backup_keys_for_room_route(
/// ///
/// Delete a key from the backup. /// Delete a key from the backup.
pub async fn delete_backup_keys_for_session_route( pub async fn delete_backup_keys_for_session_route(
db: DatabaseGuard,
body: Ruma<delete_backup_keys_for_session::v3::IncomingRequest>, body: Ruma<delete_backup_keys_for_session::v3::IncomingRequest>,
) -> Result<delete_backup_keys_for_session::v3::Response> { ) -> Result<delete_backup_keys_for_session::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.key_backups services().key_backups
.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; .delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
db.flush()?;
Ok(delete_backup_keys_for_session::v3::Response { Ok(delete_backup_keys_for_session::v3::Response {
count: (db.key_backups.count_keys(sender_user, &body.version)? as u32).into(), count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
etag: db.key_backups.get_etag(sender_user, &body.version)?, etag: services().key_backups.get_etag(sender_user, &body.version)?,
}) })
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Result, Ruma}; use crate::{Result, Ruma, services};
use ruma::api::client::discovery::get_capabilities::{ use ruma::api::client::discovery::get_capabilities::{
self, Capabilities, RoomVersionStability, RoomVersionsCapability, self, Capabilities, RoomVersionStability, RoomVersionsCapability,
}; };
@ -8,26 +8,25 @@ use std::collections::BTreeMap;
/// ///
/// Get information on the supported feature set and other relevent capabilities of this server. /// Get information on the supported feature set and other relevent capabilities of this server.
pub async fn get_capabilities_route( pub async fn get_capabilities_route(
db: DatabaseGuard,
_body: Ruma<get_capabilities::v3::IncomingRequest>, _body: Ruma<get_capabilities::v3::IncomingRequest>,
) -> Result<get_capabilities::v3::Response> { ) -> Result<get_capabilities::v3::Response> {
let mut available = BTreeMap::new(); let mut available = BTreeMap::new();
if db.globals.allow_unstable_room_versions() { if services().globals.allow_unstable_room_versions() {
for room_version in &db.globals.unstable_room_versions { for room_version in &services().globals.unstable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Stable); available.insert(room_version.clone(), RoomVersionStability::Stable);
} }
} else { } else {
for room_version in &db.globals.unstable_room_versions { for room_version in &services().globals.unstable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Unstable); available.insert(room_version.clone(), RoomVersionStability::Unstable);
} }
} }
for room_version in &db.globals.stable_room_versions { for room_version in &services().globals.stable_room_versions {
available.insert(room_version.clone(), RoomVersionStability::Stable); available.insert(room_version.clone(), RoomVersionStability::Stable);
} }
let mut capabilities = Capabilities::new(); let mut capabilities = Capabilities::new();
capabilities.room_versions = RoomVersionsCapability { capabilities.room_versions = RoomVersionsCapability {
default: db.globals.default_room_version(), default: services().globals.default_room_version(),
available, available,
}; };

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{ api::client::{
config::{ config::{
@ -17,7 +17,6 @@ use serde_json::{json, value::RawValue as RawJsonValue};
/// ///
/// Sets some account data for the sender user. /// Sets some account data for the sender user.
pub async fn set_global_account_data_route( pub async fn set_global_account_data_route(
db: DatabaseGuard,
body: Ruma<set_global_account_data::v3::IncomingRequest>, body: Ruma<set_global_account_data::v3::IncomingRequest>,
) -> Result<set_global_account_data::v3::Response> { ) -> Result<set_global_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -27,7 +26,7 @@ pub async fn set_global_account_data_route(
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
db.account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
event_type.clone().into(), event_type.clone().into(),
@ -35,11 +34,8 @@ pub async fn set_global_account_data_route(
"type": event_type, "type": event_type,
"content": data, "content": data,
}), }),
&db.globals,
)?; )?;
db.flush()?;
Ok(set_global_account_data::v3::Response {}) Ok(set_global_account_data::v3::Response {})
} }
@ -47,7 +43,6 @@ pub async fn set_global_account_data_route(
/// ///
/// Sets some room account data for the sender user. /// Sets some room account data for the sender user.
pub async fn set_room_account_data_route( pub async fn set_room_account_data_route(
db: DatabaseGuard,
body: Ruma<set_room_account_data::v3::IncomingRequest>, body: Ruma<set_room_account_data::v3::IncomingRequest>,
) -> Result<set_room_account_data::v3::Response> { ) -> Result<set_room_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -57,7 +52,7 @@ pub async fn set_room_account_data_route(
let event_type = body.event_type.to_string(); let event_type = body.event_type.to_string();
db.account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
event_type.clone().into(), event_type.clone().into(),
@ -65,11 +60,8 @@ pub async fn set_room_account_data_route(
"type": event_type, "type": event_type,
"content": data, "content": data,
}), }),
&db.globals,
)?; )?;
db.flush()?;
Ok(set_room_account_data::v3::Response {}) Ok(set_room_account_data::v3::Response {})
} }
@ -77,12 +69,11 @@ pub async fn set_room_account_data_route(
/// ///
/// Gets some account data for the sender user. /// Gets some account data for the sender user.
pub async fn get_global_account_data_route( pub async fn get_global_account_data_route(
db: DatabaseGuard,
body: Ruma<get_global_account_data::v3::IncomingRequest>, body: Ruma<get_global_account_data::v3::IncomingRequest>,
) -> Result<get_global_account_data::v3::Response> { ) -> Result<get_global_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = db let event: Box<RawJsonValue> = services()
.account_data .account_data
.get(None, sender_user, body.event_type.clone().into())? .get(None, sender_user, body.event_type.clone().into())?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?;
@ -98,12 +89,11 @@ pub async fn get_global_account_data_route(
/// ///
/// Gets some room account data for the sender user. /// Gets some room account data for the sender user.
pub async fn get_room_account_data_route( pub async fn get_room_account_data_route(
db: DatabaseGuard,
body: Ruma<get_room_account_data::v3::IncomingRequest>, body: Ruma<get_room_account_data::v3::IncomingRequest>,
) -> Result<get_room_account_data::v3::Response> { ) -> Result<get_room_account_data::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: Box<RawJsonValue> = db let event: Box<RawJsonValue> = services()
.account_data .account_data
.get( .get(
Some(&body.room_id), Some(&body.room_id),

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
events::StateEventType, events::StateEventType,
@ -13,7 +13,6 @@ use tracing::error;
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was /// - Only works if the user is joined (TODO: always allow, but only show events if the user was
/// joined, depending on history_visibility) /// joined, depending on history_visibility)
pub async fn get_context_route( pub async fn get_context_route(
db: DatabaseGuard,
body: Ruma<get_context::v3::IncomingRequest>, body: Ruma<get_context::v3::IncomingRequest>,
) -> Result<get_context::v3::Response> { ) -> Result<get_context::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -28,7 +27,7 @@ pub async fn get_context_route(
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let base_pdu_id = db let base_pdu_id = services()
.rooms .rooms
.get_pdu_id(&body.event_id)? .get_pdu_id(&body.event_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -36,9 +35,9 @@ pub async fn get_context_route(
"Base event id not found.", "Base event id not found.",
))?; ))?;
let base_token = db.rooms.pdu_count(&base_pdu_id)?; let base_token = services().rooms.pdu_count(&base_pdu_id)?;
let base_event = db let base_event = services()
.rooms .rooms
.get_pdu_from_id(&base_pdu_id)? .get_pdu_from_id(&base_pdu_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -48,14 +47,14 @@ pub async fn get_context_route(
let room_id = base_event.room_id.clone(); let room_id = base_event.room_id.clone();
if !db.rooms.is_joined(sender_user, &room_id)? { if !services().rooms.is_joined(sender_user, &room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -67,7 +66,7 @@ pub async fn get_context_route(
let base_event = base_event.to_room_event(); let base_event = base_event.to_room_event();
let events_before: Vec<_> = db let events_before: Vec<_> = services()
.rooms .rooms
.pdus_until(sender_user, &room_id, base_token)? .pdus_until(sender_user, &room_id, base_token)?
.take( .take(
@ -80,7 +79,7 @@ pub async fn get_context_route(
.collect(); .collect();
for (_, event) in &events_before { for (_, event) in &events_before {
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -93,7 +92,7 @@ pub async fn get_context_route(
let start_token = events_before let start_token = events_before
.last() .last()
.and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok())
.map(|count| count.to_string()); .map(|count| count.to_string());
let events_before: Vec<_> = events_before let events_before: Vec<_> = events_before
@ -101,7 +100,7 @@ pub async fn get_context_route(
.map(|(_, pdu)| pdu.to_room_event()) .map(|(_, pdu)| pdu.to_room_event())
.collect(); .collect();
let events_after: Vec<_> = db let events_after: Vec<_> = services()
.rooms .rooms
.pdus_after(sender_user, &room_id, base_token)? .pdus_after(sender_user, &room_id, base_token)?
.take( .take(
@ -114,7 +113,7 @@ pub async fn get_context_route(
.collect(); .collect();
for (_, event) in &events_after { for (_, event) in &events_after {
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&room_id, &room_id,
@ -125,23 +124,23 @@ pub async fn get_context_route(
} }
} }
let shortstatehash = match db.rooms.pdu_shortstatehash( let shortstatehash = match services().rooms.pdu_shortstatehash(
events_after events_after
.last() .last()
.map_or(&*body.event_id, |(_, e)| &*e.event_id), .map_or(&*body.event_id, |(_, e)| &*e.event_id),
)? { )? {
Some(s) => s, Some(s) => s,
None => db None => services()
.rooms .rooms
.current_shortstatehash(&room_id)? .current_shortstatehash(&room_id)?
.expect("All rooms have state"), .expect("All rooms have state"),
}; };
let state_ids = db.rooms.state_full_ids(shortstatehash).await?; let state_ids = services().rooms.state_full_ids(shortstatehash).await?;
let end_token = events_after let end_token = events_after
.last() .last()
.and_then(|(pdu_id, _)| db.rooms.pdu_count(pdu_id).ok()) .and_then(|(pdu_id, _)| services().rooms.pdu_count(pdu_id).ok())
.map(|count| count.to_string()); .map(|count| count.to_string());
let events_after: Vec<_> = events_after let events_after: Vec<_> = events_after
@ -152,10 +151,10 @@ pub async fn get_context_route(
let mut state = Vec::new(); let mut state = Vec::new();
for (shortstatekey, id) in state_ids { for (shortstatekey, id) in state_ids {
let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let pdu = match db.rooms.get_pdu(&id)? { let pdu = match services().rooms.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
@ -164,7 +163,7 @@ pub async fn get_context_route(
}; };
state.push(pdu.to_state_event()); state.push(pdu.to_state_event());
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
let pdu = match db.rooms.get_pdu(&id)? { let pdu = match services().rooms.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services};
use ruma::api::client::{ use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind, error::ErrorKind,
@ -11,12 +11,11 @@ use super::SESSION_ID_LENGTH;
/// ///
/// Get metadata on all devices of the sender user. /// Get metadata on all devices of the sender user.
pub async fn get_devices_route( pub async fn get_devices_route(
db: DatabaseGuard,
body: Ruma<get_devices::v3::Request>, body: Ruma<get_devices::v3::Request>,
) -> Result<get_devices::v3::Response> { ) -> Result<get_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let devices: Vec<device::Device> = db let devices: Vec<device::Device> = services()
.users .users
.all_devices_metadata(sender_user) .all_devices_metadata(sender_user)
.filter_map(|r| r.ok()) // Filter out buggy devices .filter_map(|r| r.ok()) // Filter out buggy devices
@ -29,12 +28,11 @@ pub async fn get_devices_route(
/// ///
/// Get metadata on a single device of the sender user. /// Get metadata on a single device of the sender user.
pub async fn get_device_route( pub async fn get_device_route(
db: DatabaseGuard,
body: Ruma<get_device::v3::IncomingRequest>, body: Ruma<get_device::v3::IncomingRequest>,
) -> Result<get_device::v3::Response> { ) -> Result<get_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let device = db let device = services()
.users .users
.get_device_metadata(sender_user, &body.body.device_id)? .get_device_metadata(sender_user, &body.body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
@ -46,23 +44,20 @@ pub async fn get_device_route(
/// ///
/// Updates the metadata on a given device of the sender user. /// Updates the metadata on a given device of the sender user.
pub async fn update_device_route( pub async fn update_device_route(
db: DatabaseGuard,
body: Ruma<update_device::v3::IncomingRequest>, body: Ruma<update_device::v3::IncomingRequest>,
) -> Result<update_device::v3::Response> { ) -> Result<update_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut device = db let mut device = services()
.users .users
.get_device_metadata(sender_user, &body.device_id)? .get_device_metadata(sender_user, &body.device_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
device.display_name = body.display_name.clone(); device.display_name = body.display_name.clone();
db.users services().users
.update_device_metadata(sender_user, &body.device_id, &device)?; .update_device_metadata(sender_user, &body.device_id, &device)?;
db.flush()?;
Ok(update_device::v3::Response {}) Ok(update_device::v3::Response {})
} }
@ -76,7 +71,6 @@ pub async fn update_device_route(
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn delete_device_route( pub async fn delete_device_route(
db: DatabaseGuard,
body: Ruma<delete_device::v3::IncomingRequest>, body: Ruma<delete_device::v3::IncomingRequest>,
) -> Result<delete_device::v3::Response> { ) -> Result<delete_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -94,13 +88,11 @@ pub async fn delete_device_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user, sender_user,
sender_device, sender_device,
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -108,16 +100,14 @@ pub async fn delete_device_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa services().uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
} }
db.users.remove_device(sender_user, &body.device_id)?; services().users.remove_device(sender_user, &body.device_id)?;
db.flush()?;
Ok(delete_device::v3::Response {}) Ok(delete_device::v3::Response {})
} }
@ -134,7 +124,6 @@ pub async fn delete_device_route(
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn delete_devices_route( pub async fn delete_devices_route(
db: DatabaseGuard,
body: Ruma<delete_devices::v3::IncomingRequest>, body: Ruma<delete_devices::v3::IncomingRequest>,
) -> Result<delete_devices::v3::Response> { ) -> Result<delete_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -152,13 +141,11 @@ pub async fn delete_devices_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user, sender_user,
sender_device, sender_device,
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -166,7 +153,7 @@ pub async fn delete_devices_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa services().uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
@ -174,10 +161,8 @@ pub async fn delete_devices_route(
} }
for device_id in &body.devices { for device_id in &body.devices {
db.users.remove_device(sender_user, device_id)? services().users.remove_device(sender_user, device_id)?
} }
db.flush()?;
Ok(delete_devices::v3::Response {}) Ok(delete_devices::v3::Response {})
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -37,11 +37,9 @@ use tracing::{info, warn};
/// ///
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
pub async fn get_public_rooms_filtered_route( pub async fn get_public_rooms_filtered_route(
db: DatabaseGuard,
body: Ruma<get_public_rooms_filtered::v3::IncomingRequest>, body: Ruma<get_public_rooms_filtered::v3::IncomingRequest>,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
get_public_rooms_filtered_helper( get_public_rooms_filtered_helper(
&db,
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
@ -57,11 +55,9 @@ pub async fn get_public_rooms_filtered_route(
/// ///
/// - Rooms are ordered by the number of joined members /// - Rooms are ordered by the number of joined members
pub async fn get_public_rooms_route( pub async fn get_public_rooms_route(
db: DatabaseGuard,
body: Ruma<get_public_rooms::v3::IncomingRequest>, body: Ruma<get_public_rooms::v3::IncomingRequest>,
) -> Result<get_public_rooms::v3::Response> { ) -> Result<get_public_rooms::v3::Response> {
let response = get_public_rooms_filtered_helper( let response = get_public_rooms_filtered_helper(
&db,
body.server.as_deref(), body.server.as_deref(),
body.limit, body.limit,
body.since.as_deref(), body.since.as_deref(),
@ -84,17 +80,16 @@ pub async fn get_public_rooms_route(
/// ///
/// - TODO: Access control checks /// - TODO: Access control checks
pub async fn set_room_visibility_route( pub async fn set_room_visibility_route(
db: DatabaseGuard,
body: Ruma<set_room_visibility::v3::IncomingRequest>, body: Ruma<set_room_visibility::v3::IncomingRequest>,
) -> Result<set_room_visibility::v3::Response> { ) -> Result<set_room_visibility::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
match &body.visibility { match &body.visibility {
room::Visibility::Public => { room::Visibility::Public => {
db.rooms.set_public(&body.room_id, true)?; services().rooms.set_public(&body.room_id, true)?;
info!("{} made {} public", sender_user, body.room_id); info!("{} made {} public", sender_user, body.room_id);
} }
room::Visibility::Private => db.rooms.set_public(&body.room_id, false)?, room::Visibility::Private => services().rooms.set_public(&body.room_id, false)?,
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
@ -103,8 +98,6 @@ pub async fn set_room_visibility_route(
} }
} }
db.flush()?;
Ok(set_room_visibility::v3::Response {}) Ok(set_room_visibility::v3::Response {})
} }
@ -112,11 +105,10 @@ pub async fn set_room_visibility_route(
/// ///
/// Gets the visibility of a given room in the room directory. /// Gets the visibility of a given room in the room directory.
pub async fn get_room_visibility_route( pub async fn get_room_visibility_route(
db: DatabaseGuard,
body: Ruma<get_room_visibility::v3::IncomingRequest>, body: Ruma<get_room_visibility::v3::IncomingRequest>,
) -> Result<get_room_visibility::v3::Response> { ) -> Result<get_room_visibility::v3::Response> {
Ok(get_room_visibility::v3::Response { Ok(get_room_visibility::v3::Response {
visibility: if db.rooms.is_public_room(&body.room_id)? { visibility: if services().rooms.is_public_room(&body.room_id)? {
room::Visibility::Public room::Visibility::Public
} else { } else {
room::Visibility::Private room::Visibility::Private
@ -125,19 +117,17 @@ pub async fn get_room_visibility_route(
} }
pub(crate) async fn get_public_rooms_filtered_helper( pub(crate) async fn get_public_rooms_filtered_helper(
db: &Database,
server: Option<&ServerName>, server: Option<&ServerName>,
limit: Option<UInt>, limit: Option<UInt>,
since: Option<&str>, since: Option<&str>,
filter: &IncomingFilter, filter: &IncomingFilter,
_network: &IncomingRoomNetwork, _network: &IncomingRoomNetwork,
) -> Result<get_public_rooms_filtered::v3::Response> { ) -> Result<get_public_rooms_filtered::v3::Response> {
if let Some(other_server) = server.filter(|server| *server != db.globals.server_name().as_str()) if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str())
{ {
let response = db let response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
other_server, other_server,
federation::directory::get_public_rooms_filtered::v1::Request { federation::directory::get_public_rooms_filtered::v1::Request {
limit, limit,
@ -184,14 +174,14 @@ pub(crate) async fn get_public_rooms_filtered_helper(
} }
} }
let mut all_rooms: Vec<_> = db let mut all_rooms: Vec<_> = services()
.rooms .rooms
.public_rooms() .public_rooms()
.map(|room_id| { .map(|room_id| {
let room_id = room_id?; let room_id = room_id?;
let chunk = PublicRoomsChunk { let chunk = PublicRoomsChunk {
canonical_alias: db canonical_alias: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")?
.map_or(Ok(None), |s| { .map_or(Ok(None), |s| {
@ -201,7 +191,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Error::bad_database("Invalid canonical alias event in database.") Error::bad_database("Invalid canonical alias event in database.")
}) })
})?, })?,
name: db name: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomName, "")? .room_state_get(&room_id, &StateEventType::RoomName, "")?
.map_or(Ok(None), |s| { .map_or(Ok(None), |s| {
@ -211,7 +201,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Error::bad_database("Invalid room name event in database.") Error::bad_database("Invalid room name event in database.")
}) })
})?, })?,
num_joined_members: db num_joined_members: services()
.rooms .rooms
.room_joined_count(&room_id)? .room_joined_count(&room_id)?
.unwrap_or_else(|| { .unwrap_or_else(|| {
@ -220,7 +210,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
}) })
.try_into() .try_into()
.expect("user count should not be that big"), .expect("user count should not be that big"),
topic: db topic: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomTopic, "")? .room_state_get(&room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| { .map_or(Ok(None), |s| {
@ -230,7 +220,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Error::bad_database("Invalid room topic event in database.") Error::bad_database("Invalid room topic event in database.")
}) })
})?, })?,
world_readable: db world_readable: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(false), |s| { .map_or(Ok(false), |s| {
@ -244,7 +234,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
) )
}) })
})?, })?,
guest_can_join: db guest_can_join: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
.map_or(Ok(false), |s| { .map_or(Ok(false), |s| {
@ -256,7 +246,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
Error::bad_database("Invalid room guest access event in database.") Error::bad_database("Invalid room guest access event in database.")
}) })
})?, })?,
avatar_url: db avatar_url: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomAvatar, "")? .room_state_get(&room_id, &StateEventType::RoomAvatar, "")?
.map(|s| { .map(|s| {
@ -269,7 +259,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
.transpose()? .transpose()?
// url is now an Option<String> so we must flatten // url is now an Option<String> so we must flatten
.flatten(), .flatten(),
join_rule: db join_rule: services()
.rooms .rooms
.room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| { .map(|s| {

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
filter::{create_filter, get_filter}, filter::{create_filter, get_filter},
@ -10,11 +10,10 @@ use ruma::api::client::{
/// ///
/// - A user can only access their own filters /// - A user can only access their own filters
pub async fn get_filter_route( pub async fn get_filter_route(
db: DatabaseGuard,
body: Ruma<get_filter::v3::IncomingRequest>, body: Ruma<get_filter::v3::IncomingRequest>,
) -> Result<get_filter::v3::Response> { ) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let filter = match db.users.get_filter(sender_user, &body.filter_id)? { let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
Some(filter) => filter, Some(filter) => filter,
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")),
}; };
@ -26,11 +25,10 @@ pub async fn get_filter_route(
/// ///
/// Creates a new filter to be used by other endpoints. /// Creates a new filter to be used by other endpoints.
pub async fn create_filter_route( pub async fn create_filter_route(
db: DatabaseGuard,
body: Ruma<create_filter::v3::IncomingRequest>, body: Ruma<create_filter::v3::IncomingRequest>,
) -> Result<create_filter::v3::Response> { ) -> Result<create_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::v3::Response::new( Ok(create_filter::v3::Response::new(
db.users.create_filter(sender_user, &body.filter)?, services().users.create_filter(sender_user, &body.filter)?,
)) ))
} }

View file

@ -1,5 +1,5 @@
use super::SESSION_ID_LENGTH; use super::SESSION_ID_LENGTH;
use crate::{database::DatabaseGuard, utils, Database, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services};
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
use ruma::{ use ruma::{
api::{ api::{
@ -26,39 +26,34 @@ use std::collections::{BTreeMap, HashMap, HashSet};
/// - Adds one time keys /// - Adds one time keys
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) /// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
pub async fn upload_keys_route( pub async fn upload_keys_route(
db: DatabaseGuard,
body: Ruma<upload_keys::v3::Request>, body: Ruma<upload_keys::v3::Request>,
) -> Result<upload_keys::v3::Response> { ) -> Result<upload_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
for (key_key, key_value) in &body.one_time_keys { for (key_key, key_value) in &body.one_time_keys {
db.users services().users
.add_one_time_key(sender_user, sender_device, key_key, key_value, &db.globals)?; .add_one_time_key(sender_user, sender_device, key_key, key_value)?;
} }
if let Some(device_keys) = &body.device_keys { if let Some(device_keys) = &body.device_keys {
// TODO: merge this and the existing event? // TODO: merge this and the existing event?
// This check is needed to assure that signatures are kept // This check is needed to assure that signatures are kept
if db if services()
.users .users
.get_device_keys(sender_user, sender_device)? .get_device_keys(sender_user, sender_device)?
.is_none() .is_none()
{ {
db.users.add_device_keys( services().users.add_device_keys(
sender_user, sender_user,
sender_device, sender_device,
device_keys, device_keys,
&db.rooms,
&db.globals,
)?; )?;
} }
} }
db.flush()?;
Ok(upload_keys::v3::Response { Ok(upload_keys::v3::Response {
one_time_key_counts: db.users.count_one_time_keys(sender_user, sender_device)?, one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?,
}) })
} }
@ -70,7 +65,6 @@ pub async fn upload_keys_route(
/// - Gets master keys, self-signing keys, user signing keys and device keys. /// - Gets master keys, self-signing keys, user signing keys and device keys.
/// - The master and self-signing keys contain signatures that the user is allowed to see /// - The master and self-signing keys contain signatures that the user is allowed to see
pub async fn get_keys_route( pub async fn get_keys_route(
db: DatabaseGuard,
body: Ruma<get_keys::v3::IncomingRequest>, body: Ruma<get_keys::v3::IncomingRequest>,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -79,7 +73,6 @@ pub async fn get_keys_route(
Some(sender_user), Some(sender_user),
&body.device_keys, &body.device_keys,
|u| u == sender_user, |u| u == sender_user,
&db,
) )
.await?; .await?;
@ -90,12 +83,9 @@ pub async fn get_keys_route(
/// ///
/// Claims one-time keys /// Claims one-time keys
pub async fn claim_keys_route( pub async fn claim_keys_route(
db: DatabaseGuard,
body: Ruma<claim_keys::v3::Request>, body: Ruma<claim_keys::v3::Request>,
) -> Result<claim_keys::v3::Response> { ) -> Result<claim_keys::v3::Response> {
let response = claim_keys_helper(&body.one_time_keys, &db).await?; let response = claim_keys_helper(&body.one_time_keys).await?;
db.flush()?;
Ok(response) Ok(response)
} }
@ -106,7 +96,6 @@ pub async fn claim_keys_route(
/// ///
/// - Requires UIAA to verify password /// - Requires UIAA to verify password
pub async fn upload_signing_keys_route( pub async fn upload_signing_keys_route(
db: DatabaseGuard,
body: Ruma<upload_signing_keys::v3::IncomingRequest>, body: Ruma<upload_signing_keys::v3::IncomingRequest>,
) -> Result<upload_signing_keys::v3::Response> { ) -> Result<upload_signing_keys::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -124,13 +113,11 @@ pub async fn upload_signing_keys_route(
}; };
if let Some(auth) = &body.auth { if let Some(auth) = &body.auth {
let (worked, uiaainfo) = db.uiaa.try_auth( let (worked, uiaainfo) = services().uiaa.try_auth(
sender_user, sender_user,
sender_device, sender_device,
auth, auth,
&uiaainfo, &uiaainfo,
&db.users,
&db.globals,
)?; )?;
if !worked { if !worked {
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
@ -138,7 +125,7 @@ pub async fn upload_signing_keys_route(
// Success! // Success!
} else if let Some(json) = body.json_body { } else if let Some(json) = body.json_body {
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
db.uiaa services().uiaa
.create(sender_user, sender_device, &uiaainfo, &json)?; .create(sender_user, sender_device, &uiaainfo, &json)?;
return Err(Error::Uiaa(uiaainfo)); return Err(Error::Uiaa(uiaainfo));
} else { } else {
@ -146,18 +133,14 @@ pub async fn upload_signing_keys_route(
} }
if let Some(master_key) = &body.master_key { if let Some(master_key) = &body.master_key {
db.users.add_cross_signing_keys( services().users.add_cross_signing_keys(
sender_user, sender_user,
master_key, master_key,
&body.self_signing_key, &body.self_signing_key,
&body.user_signing_key, &body.user_signing_key,
&db.rooms,
&db.globals,
)?; )?;
} }
db.flush()?;
Ok(upload_signing_keys::v3::Response {}) Ok(upload_signing_keys::v3::Response {})
} }
@ -165,7 +148,6 @@ pub async fn upload_signing_keys_route(
/// ///
/// Uploads end-to-end key signatures from the sender user. /// Uploads end-to-end key signatures from the sender user.
pub async fn upload_signatures_route( pub async fn upload_signatures_route(
db: DatabaseGuard,
body: Ruma<upload_signatures::v3::Request>, body: Ruma<upload_signatures::v3::Request>,
) -> Result<upload_signatures::v3::Response> { ) -> Result<upload_signatures::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -205,20 +187,16 @@ pub async fn upload_signatures_route(
))? ))?
.to_owned(), .to_owned(),
); );
db.users.sign_key( services().users.sign_key(
user_id, user_id,
key_id, key_id,
signature, signature,
sender_user, sender_user,
&db.rooms,
&db.globals,
)?; )?;
} }
} }
} }
db.flush()?;
Ok(upload_signatures::v3::Response { Ok(upload_signatures::v3::Response {
failures: BTreeMap::new(), // TODO: integrate failures: BTreeMap::new(), // TODO: integrate
}) })
@ -230,7 +208,6 @@ pub async fn upload_signatures_route(
/// ///
/// - TODO: left users /// - TODO: left users
pub async fn get_key_changes_route( pub async fn get_key_changes_route(
db: DatabaseGuard,
body: Ruma<get_key_changes::v3::IncomingRequest>, body: Ruma<get_key_changes::v3::IncomingRequest>,
) -> Result<get_key_changes::v3::Response> { ) -> Result<get_key_changes::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -238,7 +215,7 @@ pub async fn get_key_changes_route(
let mut device_list_updates = HashSet::new(); let mut device_list_updates = HashSet::new();
device_list_updates.extend( device_list_updates.extend(
db.users services().users
.keys_changed( .keys_changed(
sender_user.as_str(), sender_user.as_str(),
body.from body.from
@ -253,9 +230,9 @@ pub async fn get_key_changes_route(
.filter_map(|r| r.ok()), .filter_map(|r| r.ok()),
); );
for room_id in db.rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) { for room_id in services().rooms.rooms_joined(sender_user).filter_map(|r| r.ok()) {
device_list_updates.extend( device_list_updates.extend(
db.users services().users
.keys_changed( .keys_changed(
&room_id.to_string(), &room_id.to_string(),
body.from.parse().map_err(|_| { body.from.parse().map_err(|_| {
@ -278,7 +255,6 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
sender_user: Option<&UserId>, sender_user: Option<&UserId>,
device_keys_input: &BTreeMap<Box<UserId>, Vec<Box<DeviceId>>>, device_keys_input: &BTreeMap<Box<UserId>, Vec<Box<DeviceId>>>,
allowed_signatures: F, allowed_signatures: F,
db: &Database,
) -> Result<get_keys::v3::Response> { ) -> Result<get_keys::v3::Response> {
let mut master_keys = BTreeMap::new(); let mut master_keys = BTreeMap::new();
let mut self_signing_keys = BTreeMap::new(); let mut self_signing_keys = BTreeMap::new();
@ -290,7 +266,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
for (user_id, device_ids) in device_keys_input { for (user_id, device_ids) in device_keys_input {
let user_id: &UserId = &**user_id; let user_id: &UserId = &**user_id;
if user_id.server_name() != db.globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation get_over_federation
.entry(user_id.server_name()) .entry(user_id.server_name())
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@ -300,10 +276,10 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
if device_ids.is_empty() { if device_ids.is_empty() {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for device_id in db.users.all_device_ids(user_id) { for device_id in services().users.all_device_ids(user_id) {
let device_id = device_id?; let device_id = device_id?;
if let Some(mut keys) = db.users.get_device_keys(user_id, &device_id)? { if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? {
let metadata = db let metadata = services()
.users .users
.get_device_metadata(user_id, &device_id)? .get_device_metadata(user_id, &device_id)?
.ok_or_else(|| { .ok_or_else(|| {
@ -319,8 +295,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} else { } else {
for device_id in device_ids { for device_id in device_ids {
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
if let Some(mut keys) = db.users.get_device_keys(user_id, device_id)? { if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
let metadata = db.users.get_device_metadata(user_id, device_id)?.ok_or( let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
Error::BadRequest( Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Tried to get keys for nonexistent device.", "Tried to get keys for nonexistent device.",
@ -335,17 +311,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} }
} }
if let Some(master_key) = db.users.get_master_key(user_id, &allowed_signatures)? { if let Some(master_key) = services().users.get_master_key(user_id, &allowed_signatures)? {
master_keys.insert(user_id.to_owned(), master_key); master_keys.insert(user_id.to_owned(), master_key);
} }
if let Some(self_signing_key) = db if let Some(self_signing_key) = services()
.users .users
.get_self_signing_key(user_id, &allowed_signatures)? .get_self_signing_key(user_id, &allowed_signatures)?
{ {
self_signing_keys.insert(user_id.to_owned(), self_signing_key); self_signing_keys.insert(user_id.to_owned(), self_signing_key);
} }
if Some(user_id) == sender_user { if Some(user_id) == sender_user {
if let Some(user_signing_key) = db.users.get_user_signing_key(user_id)? { if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? {
user_signing_keys.insert(user_id.to_owned(), user_signing_key); user_signing_keys.insert(user_id.to_owned(), user_signing_key);
} }
} }
@ -362,9 +338,8 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
} }
( (
server, server,
db.sending services().sending
.send_federation_request( .send_federation_request(
&db.globals,
server, server,
federation::keys::get_keys::v1::Request { federation::keys::get_keys::v1::Request {
device_keys: device_keys_input_fed, device_keys: device_keys_input_fed,
@ -417,14 +392,13 @@ fn add_unsigned_device_display_name(
pub(crate) async fn claim_keys_helper( pub(crate) async fn claim_keys_helper(
one_time_keys_input: &BTreeMap<Box<UserId>, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>, one_time_keys_input: &BTreeMap<Box<UserId>, BTreeMap<Box<DeviceId>, DeviceKeyAlgorithm>>,
db: &Database,
) -> Result<claim_keys::v3::Response> { ) -> Result<claim_keys::v3::Response> {
let mut one_time_keys = BTreeMap::new(); let mut one_time_keys = BTreeMap::new();
let mut get_over_federation = BTreeMap::new(); let mut get_over_federation = BTreeMap::new();
for (user_id, map) in one_time_keys_input { for (user_id, map) in one_time_keys_input {
if user_id.server_name() != db.globals.server_name() { if user_id.server_name() != services().globals.server_name() {
get_over_federation get_over_federation
.entry(user_id.server_name()) .entry(user_id.server_name())
.or_insert_with(Vec::new) .or_insert_with(Vec::new)
@ -434,8 +408,8 @@ pub(crate) async fn claim_keys_helper(
let mut container = BTreeMap::new(); let mut container = BTreeMap::new();
for (device_id, key_algorithm) in map { for (device_id, key_algorithm) in map {
if let Some(one_time_keys) = if let Some(one_time_keys) =
db.users services().users
.take_one_time_key(user_id, device_id, key_algorithm, &db.globals)? .take_one_time_key(user_id, device_id, key_algorithm)?
{ {
let mut c = BTreeMap::new(); let mut c = BTreeMap::new();
c.insert(one_time_keys.0, one_time_keys.1); c.insert(one_time_keys.0, one_time_keys.1);
@ -453,10 +427,9 @@ pub(crate) async fn claim_keys_helper(
one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); one_time_keys_input_fed.insert(user_id.clone(), keys.clone());
} }
// Ignore failures // Ignore failures
if let Ok(keys) = db if let Ok(keys) = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
server, server,
federation::keys::claim_keys::v1::Request { federation::keys::claim_keys::v1::Request {
one_time_keys: one_time_keys_input_fed, one_time_keys: one_time_keys_input_fed,

View file

@ -1,6 +1,5 @@
use crate::{ use crate::{
database::{media::FileMeta, DatabaseGuard}, utils, Error, Result, Ruma, services, service::media::FileMeta,
utils, Error, Result, Ruma,
}; };
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
@ -16,11 +15,10 @@ const MXC_LENGTH: usize = 32;
/// ///
/// Returns max upload size. /// Returns max upload size.
pub async fn get_media_config_route( pub async fn get_media_config_route(
db: DatabaseGuard,
_body: Ruma<get_media_config::v3::Request>, _body: Ruma<get_media_config::v3::Request>,
) -> Result<get_media_config::v3::Response> { ) -> Result<get_media_config::v3::Response> {
Ok(get_media_config::v3::Response { Ok(get_media_config::v3::Response {
upload_size: db.globals.max_request_size().into(), upload_size: services().globals.max_request_size().into(),
}) })
} }
@ -31,19 +29,17 @@ pub async fn get_media_config_route(
/// - Some metadata will be saved in the database /// - Some metadata will be saved in the database
/// - Media will be saved in the media/ directory /// - Media will be saved in the media/ directory
pub async fn create_content_route( pub async fn create_content_route(
db: DatabaseGuard,
body: Ruma<create_content::v3::IncomingRequest>, body: Ruma<create_content::v3::IncomingRequest>,
) -> Result<create_content::v3::Response> { ) -> Result<create_content::v3::Response> {
let mxc = format!( let mxc = format!(
"mxc://{}/{}", "mxc://{}/{}",
db.globals.server_name(), services().globals.server_name(),
utils::random_string(MXC_LENGTH) utils::random_string(MXC_LENGTH)
); );
db.media services().media
.create( .create(
mxc.clone(), mxc.clone(),
&db.globals,
&body &body
.filename .filename
.as_ref() .as_ref()
@ -54,8 +50,6 @@ pub async fn create_content_route(
) )
.await?; .await?;
db.flush()?;
Ok(create_content::v3::Response { Ok(create_content::v3::Response {
content_uri: mxc.try_into().expect("Invalid mxc:// URI"), content_uri: mxc.try_into().expect("Invalid mxc:// URI"),
blurhash: None, blurhash: None,
@ -63,15 +57,13 @@ pub async fn create_content_route(
} }
pub async fn get_remote_content( pub async fn get_remote_content(
db: &DatabaseGuard,
mxc: &str, mxc: &str,
server_name: &ruma::ServerName, server_name: &ruma::ServerName,
media_id: &str, media_id: &str,
) -> Result<get_content::v3::Response, Error> { ) -> Result<get_content::v3::Response, Error> {
let content_response = db let content_response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
server_name, server_name,
get_content::v3::Request { get_content::v3::Request {
allow_remote: false, allow_remote: false,
@ -81,10 +73,9 @@ pub async fn get_remote_content(
) )
.await?; .await?;
db.media services().media
.create( .create(
mxc.to_string(), mxc.to_string(),
&db.globals,
&content_response.content_disposition.as_deref(), &content_response.content_disposition.as_deref(),
&content_response.content_type.as_deref(), &content_response.content_type.as_deref(),
&content_response.file, &content_response.file,
@ -100,7 +91,6 @@ pub async fn get_remote_content(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
pub async fn get_content_route( pub async fn get_content_route(
db: DatabaseGuard,
body: Ruma<get_content::v3::IncomingRequest>, body: Ruma<get_content::v3::IncomingRequest>,
) -> Result<get_content::v3::Response> { ) -> Result<get_content::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -109,16 +99,16 @@ pub async fn get_content_route(
content_disposition, content_disposition,
content_type, content_type,
file, file,
}) = db.media.get(&db.globals, &mxc).await? }) = services().media.get(&mxc).await?
{ {
Ok(get_content::v3::Response { Ok(get_content::v3::Response {
file, file,
content_type, content_type,
content_disposition, content_disposition,
}) })
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
let remote_content_response = let remote_content_response =
get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; get_remote_content(&mxc, &body.server_name, &body.media_id).await?;
Ok(remote_content_response) Ok(remote_content_response)
} else { } else {
Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."))
@ -131,7 +121,6 @@ pub async fn get_content_route(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
pub async fn get_content_as_filename_route( pub async fn get_content_as_filename_route(
db: DatabaseGuard,
body: Ruma<get_content_as_filename::v3::IncomingRequest>, body: Ruma<get_content_as_filename::v3::IncomingRequest>,
) -> Result<get_content_as_filename::v3::Response> { ) -> Result<get_content_as_filename::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
@ -140,16 +129,16 @@ pub async fn get_content_as_filename_route(
content_disposition: _, content_disposition: _,
content_type, content_type,
file, file,
}) = db.media.get(&db.globals, &mxc).await? }) = services().media.get(&mxc).await?
{ {
Ok(get_content_as_filename::v3::Response { Ok(get_content_as_filename::v3::Response {
file, file,
content_type, content_type,
content_disposition: Some(format!("inline; filename={}", body.filename)), content_disposition: Some(format!("inline; filename={}", body.filename)),
}) })
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
let remote_content_response = let remote_content_response =
get_remote_content(&db, &mxc, &body.server_name, &body.media_id).await?; get_remote_content(&mxc, &body.server_name, &body.media_id).await?;
Ok(get_content_as_filename::v3::Response { Ok(get_content_as_filename::v3::Response {
content_disposition: Some(format!("inline: filename={}", body.filename)), content_disposition: Some(format!("inline: filename={}", body.filename)),
@ -167,18 +156,16 @@ pub async fn get_content_as_filename_route(
/// ///
/// - Only allows federation if `allow_remote` is true /// - Only allows federation if `allow_remote` is true
pub async fn get_content_thumbnail_route( pub async fn get_content_thumbnail_route(
db: DatabaseGuard,
body: Ruma<get_content_thumbnail::v3::IncomingRequest>, body: Ruma<get_content_thumbnail::v3::IncomingRequest>,
) -> Result<get_content_thumbnail::v3::Response> { ) -> Result<get_content_thumbnail::v3::Response> {
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
if let Some(FileMeta { if let Some(FileMeta {
content_type, file, .. content_type, file, ..
}) = db }) = services()
.media .media
.get_thumbnail( .get_thumbnail(
&mxc, &mxc,
&db.globals,
body.width body.width
.try_into() .try_into()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
@ -189,11 +176,10 @@ pub async fn get_content_thumbnail_route(
.await? .await?
{ {
Ok(get_content_thumbnail::v3::Response { file, content_type }) Ok(get_content_thumbnail::v3::Response { file, content_type })
} else if &*body.server_name != db.globals.server_name() && body.allow_remote { } else if &*body.server_name != services().globals.server_name() && body.allow_remote {
let get_thumbnail_response = db let get_thumbnail_response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
&body.server_name, &body.server_name,
get_content_thumbnail::v3::Request { get_content_thumbnail::v3::Request {
allow_remote: false, allow_remote: false,
@ -206,10 +192,9 @@ pub async fn get_content_thumbnail_route(
) )
.await?; .await?;
db.media services().media
.upload_thumbnail( .upload_thumbnail(
mxc, mxc,
&db.globals,
&None, &None,
&get_thumbnail_response.content_type, &get_thumbnail_response.content_type,
body.width.try_into().expect("all UInts are valid u32s"), body.width.try_into().expect("all UInts are valid u32s"),

File diff suppressed because it is too large Load diff

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -19,14 +19,13 @@ use std::{
/// - The only requirement for the content is that it has to be valid json /// - The only requirement for the content is that it has to be valid json
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is allowed
pub async fn send_message_event_route( pub async fn send_message_event_route(
db: DatabaseGuard,
body: Ruma<send_message_event::v3::IncomingRequest>, body: Ruma<send_message_event::v3::IncomingRequest>,
) -> Result<send_message_event::v3::Response> { ) -> Result<send_message_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -37,7 +36,7 @@ pub async fn send_message_event_route(
// Forbid m.room.encrypted if encryption is disabled // Forbid m.room.encrypted if encryption is disabled
if RoomEventType::RoomEncrypted == body.event_type.to_string().into() if RoomEventType::RoomEncrypted == body.event_type.to_string().into()
&& !db.globals.allow_encryption() && !services().globals.allow_encryption()
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
@ -47,7 +46,7 @@ pub async fn send_message_event_route(
// Check if this is a new transaction id // Check if this is a new transaction id
if let Some(response) = if let Some(response) =
db.transaction_ids services().transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)?
{ {
// The client might have sent a txnid of the /sendToDevice endpoint // The client might have sent a txnid of the /sendToDevice endpoint
@ -69,7 +68,7 @@ pub async fn send_message_event_route(
let mut unsigned = BTreeMap::new(); let mut unsigned = BTreeMap::new();
unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into());
let event_id = db.rooms.build_and_append_pdu( let event_id = services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: body.event_type.to_string().into(), event_type: body.event_type.to_string().into(),
content: serde_json::from_str(body.body.body.json().get()) content: serde_json::from_str(body.body.body.json().get())
@ -80,11 +79,10 @@ pub async fn send_message_event_route(
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&db,
&state_lock, &state_lock,
)?; )?;
db.transaction_ids.add_txnid( services().transaction_ids.add_txnid(
sender_user, sender_user,
sender_device, sender_device,
&body.txn_id, &body.txn_id,
@ -93,8 +91,6 @@ pub async fn send_message_event_route(
drop(state_lock); drop(state_lock);
db.flush()?;
Ok(send_message_event::v3::Response::new( Ok(send_message_event::v3::Response::new(
(*event_id).to_owned(), (*event_id).to_owned(),
)) ))
@ -107,13 +103,12 @@ pub async fn send_message_event_route(
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was /// - Only works if the user is joined (TODO: always allow, but only show events where the user was
/// joined, depending on history_visibility) /// joined, depending on history_visibility)
pub async fn get_message_events_route( pub async fn get_message_events_route(
db: DatabaseGuard,
body: Ruma<get_message_events::v3::IncomingRequest>, body: Ruma<get_message_events::v3::IncomingRequest>,
) -> Result<get_message_events::v3::Response> { ) -> Result<get_message_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
if !db.rooms.is_joined(sender_user, &body.room_id)? { if !services().rooms.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -133,7 +128,7 @@ pub async fn get_message_events_route(
let to = body.to.as_ref().map(|t| t.parse()); let to = body.to.as_ref().map(|t| t.parse());
db.rooms services().rooms
.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?; .lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
// Use limit or else 10 // Use limit or else 10
@ -147,13 +142,13 @@ pub async fn get_message_events_route(
match body.dir { match body.dir {
get_message_events::v3::Direction::Forward => { get_message_events::v3::Direction::Forward => {
let events_after: Vec<_> = db let events_after: Vec<_> = services()
.rooms .rooms
.pdus_after(sender_user, &body.room_id, from)? .pdus_after(sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {
db.rooms services().rooms
.pdu_count(&pdu_id) .pdu_count(&pdu_id)
.map(|pdu_count| (pdu_count, pdu)) .map(|pdu_count| (pdu_count, pdu))
.ok() .ok()
@ -162,7 +157,7 @@ pub async fn get_message_events_route(
.collect(); .collect();
for (_, event) in &events_after { for (_, event) in &events_after {
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
@ -184,13 +179,13 @@ pub async fn get_message_events_route(
resp.chunk = events_after; resp.chunk = events_after;
} }
get_message_events::v3::Direction::Backward => { get_message_events::v3::Direction::Backward => {
let events_before: Vec<_> = db let events_before: Vec<_> = services()
.rooms .rooms
.pdus_until(sender_user, &body.room_id, from)? .pdus_until(sender_user, &body.room_id, from)?
.take(limit) .take(limit)
.filter_map(|r| r.ok()) // Filter out buggy events .filter_map(|r| r.ok()) // Filter out buggy events
.filter_map(|(pdu_id, pdu)| { .filter_map(|(pdu_id, pdu)| {
db.rooms services().rooms
.pdu_count(&pdu_id) .pdu_count(&pdu_id)
.map(|pdu_count| (pdu_count, pdu)) .map(|pdu_count| (pdu_count, pdu))
.ok() .ok()
@ -199,7 +194,7 @@ pub async fn get_message_events_route(
.collect(); .collect();
for (_, event) in &events_before { for (_, event) in &events_before {
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,
@ -225,7 +220,7 @@ pub async fn get_message_events_route(
resp.state = Vec::new(); resp.state = Vec::new();
for ll_id in &lazy_loaded { for ll_id in &lazy_loaded {
if let Some(member_event) = if let Some(member_event) =
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())? .room_state_get(&body.room_id, &StateEventType::RoomMember, ll_id.as_str())?
{ {
resp.state.push(member_event.to_state_event()); resp.state.push(member_event.to_state_event());
@ -233,7 +228,7 @@ pub async fn get_message_events_route(
} }
if let Some(next_token) = next_token { if let Some(next_token) = next_token {
db.rooms.lazy_load_mark_sent( services().rooms.lazy_load_mark_sent(
sender_user, sender_user,
sender_device, sender_device,
&body.room_id, &body.room_id,

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, utils, Result, Ruma}; use crate::{utils, Result, Ruma, services};
use ruma::api::client::presence::{get_presence, set_presence}; use ruma::api::client::presence::{get_presence, set_presence};
use std::time::Duration; use std::time::Duration;
@ -6,22 +6,21 @@ use std::time::Duration;
/// ///
/// Sets the presence state of the sender user. /// Sets the presence state of the sender user.
pub async fn set_presence_route( pub async fn set_presence_route(
db: DatabaseGuard,
body: Ruma<set_presence::v3::IncomingRequest>, body: Ruma<set_presence::v3::IncomingRequest>,
) -> Result<set_presence::v3::Response> { ) -> Result<set_presence::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for room_id in db.rooms.rooms_joined(sender_user) { for room_id in services().rooms.rooms_joined(sender_user) {
let room_id = room_id?; let room_id = room_id?;
db.rooms.edus.update_presence( services().rooms.edus.update_presence(
sender_user, sender_user,
&room_id, &room_id,
ruma::events::presence::PresenceEvent { ruma::events::presence::PresenceEvent {
content: ruma::events::presence::PresenceEventContent { content: ruma::events::presence::PresenceEventContent {
avatar_url: db.users.avatar_url(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?,
currently_active: None, currently_active: None,
displayname: db.users.displayname(sender_user)?, displayname: services().users.displayname(sender_user)?,
last_active_ago: Some( last_active_ago: Some(
utils::millis_since_unix_epoch() utils::millis_since_unix_epoch()
.try_into() .try_into()
@ -32,12 +31,9 @@ pub async fn set_presence_route(
}, },
sender: sender_user.clone(), sender: sender_user.clone(),
}, },
&db.globals,
)?; )?;
} }
db.flush()?;
Ok(set_presence::v3::Response {}) Ok(set_presence::v3::Response {})
} }
@ -47,20 +43,19 @@ pub async fn set_presence_route(
/// ///
/// - Only works if you share a room with the user /// - Only works if you share a room with the user
pub async fn get_presence_route( pub async fn get_presence_route(
db: DatabaseGuard,
body: Ruma<get_presence::v3::IncomingRequest>, body: Ruma<get_presence::v3::IncomingRequest>,
) -> Result<get_presence::v3::Response> { ) -> Result<get_presence::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut presence_event = None; let mut presence_event = None;
for room_id in db for room_id in services()
.rooms .rooms
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{ {
let room_id = room_id?; let room_id = room_id?;
if let Some(presence) = db if let Some(presence) = services()
.rooms .rooms
.edus .edus
.get_last_presence_event(sender_user, &room_id)? .get_last_presence_event(sender_user, &room_id)?

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, pdu::PduBuilder, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services, service::pdu::PduBuilder};
use ruma::{ use ruma::{
api::{ api::{
client::{ client::{
@ -20,16 +20,15 @@ use std::sync::Arc;
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub async fn set_displayname_route( pub async fn set_displayname_route(
db: DatabaseGuard,
body: Ruma<set_display_name::v3::IncomingRequest>, body: Ruma<set_display_name::v3::IncomingRequest>,
) -> Result<set_display_name::v3::Response> { ) -> Result<set_display_name::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.users services().users
.set_displayname(sender_user, body.displayname.clone())?; .set_displayname(sender_user, body.displayname.clone())?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_rooms_joined: Vec<_> = db let all_rooms_joined: Vec<_> = services()
.rooms .rooms
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
@ -40,7 +39,7 @@ pub async fn set_displayname_route(
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
displayname: body.displayname.clone(), displayname: body.displayname.clone(),
..serde_json::from_str( ..serde_json::from_str(
db.rooms services().rooms
.room_state_get( .room_state_get(
&room_id, &room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
@ -70,7 +69,7 @@ pub async fn set_displayname_route(
for (pdu_builder, room_id) in all_rooms_joined { for (pdu_builder, room_id) in all_rooms_joined {
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -79,19 +78,19 @@ pub async fn set_displayname_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let _ = db let _ = services()
.rooms .rooms
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock);
// Presence update // Presence update
db.rooms.edus.update_presence( services().rooms.edus.update_presence(
sender_user, sender_user,
&room_id, &room_id,
ruma::events::presence::PresenceEvent { ruma::events::presence::PresenceEvent {
content: ruma::events::presence::PresenceEventContent { content: ruma::events::presence::PresenceEventContent {
avatar_url: db.users.avatar_url(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?,
currently_active: None, currently_active: None,
displayname: db.users.displayname(sender_user)?, displayname: services().users.displayname(sender_user)?,
last_active_ago: Some( last_active_ago: Some(
utils::millis_since_unix_epoch() utils::millis_since_unix_epoch()
.try_into() .try_into()
@ -102,12 +101,9 @@ pub async fn set_displayname_route(
}, },
sender: sender_user.clone(), sender: sender_user.clone(),
}, },
&db.globals,
)?; )?;
} }
db.flush()?;
Ok(set_display_name::v3::Response {}) Ok(set_display_name::v3::Response {})
} }
@ -117,14 +113,12 @@ pub async fn set_displayname_route(
/// ///
/// - If user is on another server: Fetches displayname over federation /// - If user is on another server: Fetches displayname over federation
pub async fn get_displayname_route( pub async fn get_displayname_route(
db: DatabaseGuard,
body: Ruma<get_display_name::v3::IncomingRequest>, body: Ruma<get_display_name::v3::IncomingRequest>,
) -> Result<get_display_name::v3::Response> { ) -> Result<get_display_name::v3::Response> {
if body.user_id.server_name() != db.globals.server_name() { if body.user_id.server_name() != services().globals.server_name() {
let response = db let response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
body.user_id.server_name(), body.user_id.server_name(),
federation::query::get_profile_information::v1::Request { federation::query::get_profile_information::v1::Request {
user_id: &body.user_id, user_id: &body.user_id,
@ -139,7 +133,7 @@ pub async fn get_displayname_route(
} }
Ok(get_display_name::v3::Response { Ok(get_display_name::v3::Response {
displayname: db.users.displayname(&body.user_id)?, displayname: services().users.displayname(&body.user_id)?,
}) })
} }
@ -149,18 +143,17 @@ pub async fn get_displayname_route(
/// ///
/// - Also makes sure other users receive the update using presence EDUs /// - Also makes sure other users receive the update using presence EDUs
pub async fn set_avatar_url_route( pub async fn set_avatar_url_route(
db: DatabaseGuard,
body: Ruma<set_avatar_url::v3::IncomingRequest>, body: Ruma<set_avatar_url::v3::IncomingRequest>,
) -> Result<set_avatar_url::v3::Response> { ) -> Result<set_avatar_url::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.users services().users
.set_avatar_url(sender_user, body.avatar_url.clone())?; .set_avatar_url(sender_user, body.avatar_url.clone())?;
db.users.set_blurhash(sender_user, body.blurhash.clone())?; services().users.set_blurhash(sender_user, body.blurhash.clone())?;
// Send a new membership event and presence update into all joined rooms // Send a new membership event and presence update into all joined rooms
let all_joined_rooms: Vec<_> = db let all_joined_rooms: Vec<_> = services()
.rooms .rooms
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
@ -171,7 +164,7 @@ pub async fn set_avatar_url_route(
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
avatar_url: body.avatar_url.clone(), avatar_url: body.avatar_url.clone(),
..serde_json::from_str( ..serde_json::from_str(
db.rooms services().rooms
.room_state_get( .room_state_get(
&room_id, &room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
@ -201,7 +194,7 @@ pub async fn set_avatar_url_route(
for (pdu_builder, room_id) in all_joined_rooms { for (pdu_builder, room_id) in all_joined_rooms {
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -210,19 +203,19 @@ pub async fn set_avatar_url_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let _ = db let _ = services()
.rooms .rooms
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock); .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock);
// Presence update // Presence update
db.rooms.edus.update_presence( services().rooms.edus.update_presence(
sender_user, sender_user,
&room_id, &room_id,
ruma::events::presence::PresenceEvent { ruma::events::presence::PresenceEvent {
content: ruma::events::presence::PresenceEventContent { content: ruma::events::presence::PresenceEventContent {
avatar_url: db.users.avatar_url(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?,
currently_active: None, currently_active: None,
displayname: db.users.displayname(sender_user)?, displayname: services().users.displayname(sender_user)?,
last_active_ago: Some( last_active_ago: Some(
utils::millis_since_unix_epoch() utils::millis_since_unix_epoch()
.try_into() .try_into()
@ -233,12 +226,10 @@ pub async fn set_avatar_url_route(
}, },
sender: sender_user.clone(), sender: sender_user.clone(),
}, },
&db.globals, &services().globals,
)?; )?;
} }
db.flush()?;
Ok(set_avatar_url::v3::Response {}) Ok(set_avatar_url::v3::Response {})
} }
@ -248,14 +239,12 @@ pub async fn set_avatar_url_route(
/// ///
/// - If user is on another server: Fetches avatar_url and blurhash over federation /// - If user is on another server: Fetches avatar_url and blurhash over federation
pub async fn get_avatar_url_route( pub async fn get_avatar_url_route(
db: DatabaseGuard,
body: Ruma<get_avatar_url::v3::IncomingRequest>, body: Ruma<get_avatar_url::v3::IncomingRequest>,
) -> Result<get_avatar_url::v3::Response> { ) -> Result<get_avatar_url::v3::Response> {
if body.user_id.server_name() != db.globals.server_name() { if body.user_id.server_name() != services().globals.server_name() {
let response = db let response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
body.user_id.server_name(), body.user_id.server_name(),
federation::query::get_profile_information::v1::Request { federation::query::get_profile_information::v1::Request {
user_id: &body.user_id, user_id: &body.user_id,
@ -271,8 +260,8 @@ pub async fn get_avatar_url_route(
} }
Ok(get_avatar_url::v3::Response { Ok(get_avatar_url::v3::Response {
avatar_url: db.users.avatar_url(&body.user_id)?, avatar_url: services().users.avatar_url(&body.user_id)?,
blurhash: db.users.blurhash(&body.user_id)?, blurhash: services().users.blurhash(&body.user_id)?,
}) })
} }
@ -282,14 +271,12 @@ pub async fn get_avatar_url_route(
/// ///
/// - If user is on another server: Fetches profile over federation /// - If user is on another server: Fetches profile over federation
pub async fn get_profile_route( pub async fn get_profile_route(
db: DatabaseGuard,
body: Ruma<get_profile::v3::IncomingRequest>, body: Ruma<get_profile::v3::IncomingRequest>,
) -> Result<get_profile::v3::Response> { ) -> Result<get_profile::v3::Response> {
if body.user_id.server_name() != db.globals.server_name() { if body.user_id.server_name() != services().globals.server_name() {
let response = db let response = services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
body.user_id.server_name(), body.user_id.server_name(),
federation::query::get_profile_information::v1::Request { federation::query::get_profile_information::v1::Request {
user_id: &body.user_id, user_id: &body.user_id,
@ -305,7 +292,7 @@ pub async fn get_profile_route(
}); });
} }
if !db.users.exists(&body.user_id)? { if !services().users.exists(&body.user_id)? {
// Return 404 if this user doesn't exist // Return 404 if this user doesn't exist
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::NotFound, ErrorKind::NotFound,
@ -314,8 +301,8 @@ pub async fn get_profile_route(
} }
Ok(get_profile::v3::Response { Ok(get_profile::v3::Response {
avatar_url: db.users.avatar_url(&body.user_id)?, avatar_url: services().users.avatar_url(&body.user_id)?,
blurhash: db.users.blurhash(&body.user_id)?, blurhash: services().users.blurhash(&body.user_id)?,
displayname: db.users.displayname(&body.user_id)?, displayname: services().users.displayname(&body.user_id)?,
}) })
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -16,12 +16,11 @@ use ruma::{
/// ///
/// Retrieves the push rules event for this user. /// Retrieves the push rules event for this user.
pub async fn get_pushrules_all_route( pub async fn get_pushrules_all_route(
db: DatabaseGuard,
body: Ruma<get_pushrules_all::v3::Request>, body: Ruma<get_pushrules_all::v3::Request>,
) -> Result<get_pushrules_all::v3::Response> { ) -> Result<get_pushrules_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: PushRulesEvent = db let event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -42,12 +41,11 @@ pub async fn get_pushrules_all_route(
/// ///
/// Retrieves a single specified push rule for this user. /// Retrieves a single specified push rule for this user.
pub async fn get_pushrule_route( pub async fn get_pushrule_route(
db: DatabaseGuard,
body: Ruma<get_pushrule::v3::IncomingRequest>, body: Ruma<get_pushrule::v3::IncomingRequest>,
) -> Result<get_pushrule::v3::Response> { ) -> Result<get_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event: PushRulesEvent = db let event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -98,7 +96,6 @@ pub async fn get_pushrule_route(
/// ///
/// Creates a single specified push rule for this user. /// Creates a single specified push rule for this user.
pub async fn set_pushrule_route( pub async fn set_pushrule_route(
db: DatabaseGuard,
body: Ruma<set_pushrule::v3::IncomingRequest>, body: Ruma<set_pushrule::v3::IncomingRequest>,
) -> Result<set_pushrule::v3::Response> { ) -> Result<set_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -111,7 +108,7 @@ pub async fn set_pushrule_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -186,16 +183,13 @@ pub async fn set_pushrule_route(
_ => {} _ => {}
} }
db.account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&event, &event,
&db.globals,
)?; )?;
db.flush()?;
Ok(set_pushrule::v3::Response {}) Ok(set_pushrule::v3::Response {})
} }
@ -203,7 +197,6 @@ pub async fn set_pushrule_route(
/// ///
/// Gets the actions of a single specified push rule for this user. /// Gets the actions of a single specified push rule for this user.
pub async fn get_pushrule_actions_route( pub async fn get_pushrule_actions_route(
db: DatabaseGuard,
body: Ruma<get_pushrule_actions::v3::IncomingRequest>, body: Ruma<get_pushrule_actions::v3::IncomingRequest>,
) -> Result<get_pushrule_actions::v3::Response> { ) -> Result<get_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -215,7 +208,7 @@ pub async fn get_pushrule_actions_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -252,8 +245,6 @@ pub async fn get_pushrule_actions_route(
_ => None, _ => None,
}; };
db.flush()?;
Ok(get_pushrule_actions::v3::Response { Ok(get_pushrule_actions::v3::Response {
actions: actions.unwrap_or_default(), actions: actions.unwrap_or_default(),
}) })
@ -263,7 +254,6 @@ pub async fn get_pushrule_actions_route(
/// ///
/// Sets the actions of a single specified push rule for this user. /// Sets the actions of a single specified push rule for this user.
pub async fn set_pushrule_actions_route( pub async fn set_pushrule_actions_route(
db: DatabaseGuard,
body: Ruma<set_pushrule_actions::v3::IncomingRequest>, body: Ruma<set_pushrule_actions::v3::IncomingRequest>,
) -> Result<set_pushrule_actions::v3::Response> { ) -> Result<set_pushrule_actions::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -275,7 +265,7 @@ pub async fn set_pushrule_actions_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -322,16 +312,13 @@ pub async fn set_pushrule_actions_route(
_ => {} _ => {}
}; };
db.account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&event, &event,
&db.globals,
)?; )?;
db.flush()?;
Ok(set_pushrule_actions::v3::Response {}) Ok(set_pushrule_actions::v3::Response {})
} }
@ -339,7 +326,6 @@ pub async fn set_pushrule_actions_route(
/// ///
/// Gets the enabled status of a single specified push rule for this user. /// Gets the enabled status of a single specified push rule for this user.
pub async fn get_pushrule_enabled_route( pub async fn get_pushrule_enabled_route(
db: DatabaseGuard,
body: Ruma<get_pushrule_enabled::v3::IncomingRequest>, body: Ruma<get_pushrule_enabled::v3::IncomingRequest>,
) -> Result<get_pushrule_enabled::v3::Response> { ) -> Result<get_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -351,7 +337,7 @@ pub async fn get_pushrule_enabled_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -393,8 +379,6 @@ pub async fn get_pushrule_enabled_route(
_ => false, _ => false,
}; };
db.flush()?;
Ok(get_pushrule_enabled::v3::Response { enabled }) Ok(get_pushrule_enabled::v3::Response { enabled })
} }
@ -402,7 +386,6 @@ pub async fn get_pushrule_enabled_route(
/// ///
/// Sets the enabled status of a single specified push rule for this user. /// Sets the enabled status of a single specified push rule for this user.
pub async fn set_pushrule_enabled_route( pub async fn set_pushrule_enabled_route(
db: DatabaseGuard,
body: Ruma<set_pushrule_enabled::v3::IncomingRequest>, body: Ruma<set_pushrule_enabled::v3::IncomingRequest>,
) -> Result<set_pushrule_enabled::v3::Response> { ) -> Result<set_pushrule_enabled::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -414,7 +397,7 @@ pub async fn set_pushrule_enabled_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -466,16 +449,13 @@ pub async fn set_pushrule_enabled_route(
_ => {} _ => {}
} }
db.account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&event, &event,
&db.globals,
)?; )?;
db.flush()?;
Ok(set_pushrule_enabled::v3::Response {}) Ok(set_pushrule_enabled::v3::Response {})
} }
@ -483,7 +463,6 @@ pub async fn set_pushrule_enabled_route(
/// ///
/// Deletes a single specified push rule for this user. /// Deletes a single specified push rule for this user.
pub async fn delete_pushrule_route( pub async fn delete_pushrule_route(
db: DatabaseGuard,
body: Ruma<delete_pushrule::v3::IncomingRequest>, body: Ruma<delete_pushrule::v3::IncomingRequest>,
) -> Result<delete_pushrule::v3::Response> { ) -> Result<delete_pushrule::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -495,7 +474,7 @@ pub async fn delete_pushrule_route(
)); ));
} }
let mut event: PushRulesEvent = db let mut event: PushRulesEvent = services()
.account_data .account_data
.get( .get(
None, None,
@ -537,16 +516,13 @@ pub async fn delete_pushrule_route(
_ => {} _ => {}
} }
db.account_data.update( services().account_data.update(
None, None,
sender_user, sender_user,
GlobalAccountDataEventType::PushRules.to_string().into(), GlobalAccountDataEventType::PushRules.to_string().into(),
&event, &event,
&db.globals,
)?; )?;
db.flush()?;
Ok(delete_pushrule::v3::Response {}) Ok(delete_pushrule::v3::Response {})
} }
@ -554,13 +530,12 @@ pub async fn delete_pushrule_route(
/// ///
/// Gets all currently active pushers for the sender user. /// Gets all currently active pushers for the sender user.
pub async fn get_pushers_route( pub async fn get_pushers_route(
db: DatabaseGuard,
body: Ruma<get_pushers::v3::Request>, body: Ruma<get_pushers::v3::Request>,
) -> Result<get_pushers::v3::Response> { ) -> Result<get_pushers::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_pushers::v3::Response { Ok(get_pushers::v3::Response {
pushers: db.pusher.get_pushers(sender_user)?, pushers: services().pusher.get_pushers(sender_user)?,
}) })
} }
@ -570,15 +545,12 @@ pub async fn get_pushers_route(
/// ///
/// - TODO: Handle `append` /// - TODO: Handle `append`
pub async fn set_pushers_route( pub async fn set_pushers_route(
db: DatabaseGuard,
body: Ruma<set_pusher::v3::Request>, body: Ruma<set_pusher::v3::Request>,
) -> Result<set_pusher::v3::Response> { ) -> Result<set_pusher::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let pusher = body.pusher.clone(); let pusher = body.pusher.clone();
db.pusher.set_pusher(sender_user, pusher)?; services().pusher.set_pusher(sender_user, pusher)?;
db.flush()?;
Ok(set_pusher::v3::Response::default()) Ok(set_pusher::v3::Response::default())
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
events::RoomAccountDataEventType, events::RoomAccountDataEventType,
@ -14,7 +14,6 @@ use std::collections::BTreeMap;
/// - Updates fully-read account data event to `fully_read` /// - Updates fully-read account data event to `fully_read`
/// - If `read_receipt` is set: Update private marker and public read receipt EDU /// - If `read_receipt` is set: Update private marker and public read receipt EDU
pub async fn set_read_marker_route( pub async fn set_read_marker_route(
db: DatabaseGuard,
body: Ruma<set_read_marker::v3::IncomingRequest>, body: Ruma<set_read_marker::v3::IncomingRequest>,
) -> Result<set_read_marker::v3::Response> { ) -> Result<set_read_marker::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -24,25 +23,23 @@ pub async fn set_read_marker_route(
event_id: body.fully_read.clone(), event_id: body.fully_read.clone(),
}, },
}; };
db.account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::FullyRead, RoomAccountDataEventType::FullyRead,
&fully_read_event, &fully_read_event,
&db.globals,
)?; )?;
if let Some(event) = &body.read_receipt { if let Some(event) = &body.read_receipt {
db.rooms.edus.private_read_set( services().rooms.edus.private_read_set(
&body.room_id, &body.room_id,
sender_user, sender_user,
db.rooms.get_pdu_count(event)?.ok_or(Error::BadRequest( services().rooms.get_pdu_count(event)?.ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Event does not exist.", "Event does not exist.",
))?, ))?,
&db.globals,
)?; )?;
db.rooms services().rooms
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)?;
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
@ -59,19 +56,16 @@ pub async fn set_read_marker_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(event.to_owned(), receipts); receipt_content.insert(event.to_owned(), receipts);
db.rooms.edus.readreceipt_update( services().rooms.edus.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
&db.globals,
)?; )?;
} }
db.flush()?;
Ok(set_read_marker::v3::Response {}) Ok(set_read_marker::v3::Response {})
} }
@ -79,23 +73,21 @@ pub async fn set_read_marker_route(
/// ///
/// Sets private read marker and public read receipt EDU. /// Sets private read marker and public read receipt EDU.
pub async fn create_receipt_route( pub async fn create_receipt_route(
db: DatabaseGuard,
body: Ruma<create_receipt::v3::IncomingRequest>, body: Ruma<create_receipt::v3::IncomingRequest>,
) -> Result<create_receipt::v3::Response> { ) -> Result<create_receipt::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
db.rooms.edus.private_read_set( services().rooms.edus.private_read_set(
&body.room_id, &body.room_id,
sender_user, sender_user,
db.rooms services().rooms
.get_pdu_count(&body.event_id)? .get_pdu_count(&body.event_id)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
ErrorKind::InvalidParam, ErrorKind::InvalidParam,
"Event does not exist.", "Event does not exist.",
))?, ))?,
&db.globals,
)?; )?;
db.rooms services().rooms
.reset_notification_counts(sender_user, &body.room_id)?; .reset_notification_counts(sender_user, &body.room_id)?;
let mut user_receipts = BTreeMap::new(); let mut user_receipts = BTreeMap::new();
@ -111,17 +103,16 @@ pub async fn create_receipt_route(
let mut receipt_content = BTreeMap::new(); let mut receipt_content = BTreeMap::new();
receipt_content.insert(body.event_id.to_owned(), receipts); receipt_content.insert(body.event_id.to_owned(), receipts);
db.rooms.edus.readreceipt_update( services().rooms.edus.readreceipt_update(
sender_user, sender_user,
&body.room_id, &body.room_id,
ruma::events::receipt::ReceiptEvent { ruma::events::receipt::ReceiptEvent {
content: ruma::events::receipt::ReceiptEventContent(receipt_content), content: ruma::events::receipt::ReceiptEventContent(receipt_content),
room_id: body.room_id.clone(), room_id: body.room_id.clone(),
}, },
&db.globals,
)?; )?;
db.flush()?; services().flush()?;
Ok(create_receipt::v3::Response {}) Ok(create_receipt::v3::Response {})
} }

View file

@ -1,6 +1,6 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{database::DatabaseGuard, pdu::PduBuilder, Result, Ruma}; use crate::{Result, Ruma, services, service::pdu::PduBuilder};
use ruma::{ use ruma::{
api::client::redact::redact_event, api::client::redact::redact_event,
events::{room::redaction::RoomRedactionEventContent, RoomEventType}, events::{room::redaction::RoomRedactionEventContent, RoomEventType},
@ -14,14 +14,13 @@ use serde_json::value::to_raw_value;
/// ///
/// - TODO: Handle txn id /// - TODO: Handle txn id
pub async fn redact_event_route( pub async fn redact_event_route(
db: DatabaseGuard,
body: Ruma<redact_event::v3::IncomingRequest>, body: Ruma<redact_event::v3::IncomingRequest>,
) -> Result<redact_event::v3::Response> { ) -> Result<redact_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let body = body.body; let body = body.body;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -30,7 +29,7 @@ pub async fn redact_event_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let event_id = db.rooms.build_and_append_pdu( let event_id = services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomRedaction, event_type: RoomEventType::RoomRedaction,
content: to_raw_value(&RoomRedactionEventContent { content: to_raw_value(&RoomRedactionEventContent {
@ -43,14 +42,11 @@ pub async fn redact_event_route(
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&db,
&state_lock, &state_lock,
)?; )?;
drop(state_lock); drop(state_lock);
db.flush()?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(redact_event::v3::Response { event_id }) Ok(redact_event::v3::Response { event_id })
} }

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, utils::HtmlEscape, Error, Result, Ruma}; use crate::{utils::HtmlEscape, Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{error::ErrorKind, room::report_content}, api::client::{error::ErrorKind, room::report_content},
events::room::message, events::room::message,
@ -10,12 +10,11 @@ use ruma::{
/// Reports an inappropriate event to homeserver admins /// Reports an inappropriate event to homeserver admins
/// ///
pub async fn report_event_route( pub async fn report_event_route(
db: DatabaseGuard,
body: Ruma<report_content::v3::IncomingRequest>, body: Ruma<report_content::v3::IncomingRequest>,
) -> Result<report_content::v3::Response> { ) -> Result<report_content::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let pdu = match db.rooms.get_pdu(&body.event_id)? { let pdu = match services().rooms.get_pdu(&body.event_id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
_ => { _ => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -39,7 +38,7 @@ pub async fn report_event_route(
)); ));
}; };
db.admin services().admin
.send_message(message::RoomMessageEventContent::text_html( .send_message(message::RoomMessageEventContent::text_html(
format!( format!(
"Report received from: {}\n\n\ "Report received from: {}\n\n\
@ -66,7 +65,5 @@ pub async fn report_event_route(
), ),
)); ));
db.flush()?;
Ok(report_content::v3::Response {}) Ok(report_content::v3::Response {})
} }

View file

@ -1,5 +1,5 @@
use crate::{ use crate::{
client_server::invite_helper, database::DatabaseGuard, pdu::PduBuilder, Error, Result, Ruma, Error, Result, Ruma, service::pdu::PduBuilder, services, api::client_server::invite_helper,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -46,19 +46,18 @@ use tracing::{info, warn};
/// - Send events implied by `name` and `topic` /// - Send events implied by `name` and `topic`
/// - Send invite events /// - Send invite events
pub async fn create_room_route( pub async fn create_room_route(
db: DatabaseGuard,
body: Ruma<create_room::v3::IncomingRequest>, body: Ruma<create_room::v3::IncomingRequest>,
) -> Result<create_room::v3::Response> { ) -> Result<create_room::v3::Response> {
use create_room::v3::RoomPreset; use create_room::v3::RoomPreset;
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(services().globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; services().rooms.get_or_create_shortroomid(&room_id)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -67,9 +66,9 @@ pub async fn create_room_route(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
if !db.globals.allow_room_creation() if !services().globals.allow_room_creation()
&& !body.from_appservice && !body.from_appservice
&& !db.users.is_admin(sender_user, &db.rooms, &db.globals)? && !services().users.is_admin(sender_user)?
{ {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
@ -83,12 +82,12 @@ pub async fn create_room_route(
.map_or(Ok(None), |localpart| { .map_or(Ok(None), |localpart| {
// TODO: Check for invalid characters and maximum length // TODO: Check for invalid characters and maximum length
let alias = let alias =
RoomAliasId::parse(format!("#{}:{}", localpart, db.globals.server_name())) RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name()))
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.") Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias.")
})?; })?;
if db.rooms.id_from_alias(&alias)?.is_some() { if services().rooms.id_from_alias(&alias)?.is_some() {
Err(Error::BadRequest( Err(Error::BadRequest(
ErrorKind::RoomInUse, ErrorKind::RoomInUse,
"Room alias already exists.", "Room alias already exists.",
@ -100,7 +99,7 @@ pub async fn create_room_route(
let room_version = match body.room_version.clone() { let room_version = match body.room_version.clone() {
Some(room_version) => { Some(room_version) => {
if db.rooms.is_supported_version(&db, &room_version) { if services().rooms.is_supported_version(&services(), &room_version) {
room_version room_version
} else { } else {
return Err(Error::BadRequest( return Err(Error::BadRequest(
@ -109,7 +108,7 @@ pub async fn create_room_route(
)); ));
} }
} }
None => db.globals.default_room_version(), None => services().globals.default_room_version(),
}; };
let content = match &body.creation_content { let content = match &body.creation_content {
@ -163,7 +162,7 @@ pub async fn create_room_route(
} }
// 1. The room create event // 1. The room create event
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomCreate, event_type: RoomEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"), content: to_raw_value(&content).expect("event is valid, we just created it"),
@ -173,21 +172,20 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 2. Let the room creator join // 2. Let the room creator join
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMember, event_type: RoomEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: db.users.displayname(sender_user)?, displayname: services().users.displayname(sender_user)?,
avatar_url: db.users.avatar_url(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?,
is_direct: Some(body.is_direct), is_direct: Some(body.is_direct),
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(sender_user)?, blurhash: services().users.blurhash(sender_user)?,
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -198,7 +196,6 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
@ -240,7 +237,7 @@ pub async fn create_room_route(
} }
} }
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomPowerLevels, event_type: RoomEventType::RoomPowerLevels,
content: to_raw_value(&power_levels_content) content: to_raw_value(&power_levels_content)
@ -251,13 +248,12 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 4. Canonical room alias // 4. Canonical room alias
if let Some(room_alias_id) = &alias { if let Some(room_alias_id) = &alias {
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomCanonicalAlias, event_type: RoomEventType::RoomCanonicalAlias,
content: to_raw_value(&RoomCanonicalAliasEventContent { content: to_raw_value(&RoomCanonicalAliasEventContent {
@ -271,7 +267,6 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
} }
@ -279,7 +274,7 @@ pub async fn create_room_route(
// 5. Events set by preset // 5. Events set by preset
// 5.1 Join Rules // 5.1 Join Rules
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomJoinRules, event_type: RoomEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { content: to_raw_value(&RoomJoinRulesEventContent::new(match preset {
@ -294,12 +289,11 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 5.2 History Visibility // 5.2 History Visibility
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomHistoryVisibility, event_type: RoomEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new( content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
@ -312,12 +306,11 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 5.3 Guest Access // 5.3 Guest Access
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomGuestAccess, event_type: RoomEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { content: to_raw_value(&RoomGuestAccessEventContent::new(match preset {
@ -331,7 +324,6 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
@ -346,18 +338,18 @@ pub async fn create_room_route(
pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); pdu_builder.state_key.get_or_insert_with(|| "".to_owned());
// Silently skip encryption events if they are not allowed // Silently skip encryption events if they are not allowed
if pdu_builder.event_type == RoomEventType::RoomEncryption && !db.globals.allow_encryption() if pdu_builder.event_type == RoomEventType::RoomEncryption && !services().globals.allow_encryption()
{ {
continue; continue;
} }
db.rooms services().rooms
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &db, &state_lock)?; .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)?;
} }
// 7. Events implied by name and topic // 7. Events implied by name and topic
if let Some(name) = &body.name { if let Some(name) = &body.name {
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomName, event_type: RoomEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(Some(name.clone()))) content: to_raw_value(&RoomNameEventContent::new(Some(name.clone())))
@ -368,13 +360,12 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
} }
if let Some(topic) = &body.topic { if let Some(topic) = &body.topic {
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomTopic, event_type: RoomEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent { content: to_raw_value(&RoomTopicEventContent {
@ -387,7 +378,6 @@ pub async fn create_room_route(
}, },
sender_user, sender_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
} }
@ -395,22 +385,20 @@ pub async fn create_room_route(
// 8. Events implied by invite (and TODO: invite_3pid) // 8. Events implied by invite (and TODO: invite_3pid)
drop(state_lock); drop(state_lock);
for user_id in &body.invite { for user_id in &body.invite {
let _ = invite_helper(sender_user, user_id, &room_id, &db, body.is_direct).await; let _ = invite_helper(sender_user, user_id, &room_id, body.is_direct).await;
} }
// Homeserver specific stuff // Homeserver specific stuff
if let Some(alias) = alias { if let Some(alias) = alias {
db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; services().rooms.set_alias(&alias, Some(&room_id))?;
} }
if body.visibility == room::Visibility::Public { if body.visibility == room::Visibility::Public {
db.rooms.set_public(&room_id, true)?; services().rooms.set_public(&room_id, true)?;
} }
info!("{} created a room", sender_user); info!("{} created a room", sender_user);
db.flush()?;
Ok(create_room::v3::Response::new(room_id)) Ok(create_room::v3::Response::new(room_id))
} }
@ -420,12 +408,11 @@ pub async fn create_room_route(
/// ///
/// - You have to currently be joined to the room (TODO: Respect history visibility) /// - You have to currently be joined to the room (TODO: Respect history visibility)
pub async fn get_room_event_route( pub async fn get_room_event_route(
db: DatabaseGuard,
body: Ruma<get_room_event::v3::IncomingRequest>, body: Ruma<get_room_event::v3::IncomingRequest>,
) -> Result<get_room_event::v3::Response> { ) -> Result<get_room_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !db.rooms.is_joined(sender_user, &body.room_id)? { if !services().rooms.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -433,7 +420,7 @@ pub async fn get_room_event_route(
} }
Ok(get_room_event::v3::Response { Ok(get_room_event::v3::Response {
event: db event: services()
.rooms .rooms
.get_pdu(&body.event_id)? .get_pdu(&body.event_id)?
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))? .ok_or(Error::BadRequest(ErrorKind::NotFound, "Event not found."))?
@ -447,12 +434,11 @@ pub async fn get_room_event_route(
/// ///
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable /// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable
pub async fn get_room_aliases_route( pub async fn get_room_aliases_route(
db: DatabaseGuard,
body: Ruma<aliases::v3::IncomingRequest>, body: Ruma<aliases::v3::IncomingRequest>,
) -> Result<aliases::v3::Response> { ) -> Result<aliases::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !db.rooms.is_joined(sender_user, &body.room_id)? { if !services().rooms.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
@ -460,7 +446,7 @@ pub async fn get_room_aliases_route(
} }
Ok(aliases::v3::Response { Ok(aliases::v3::Response {
aliases: db aliases: services()
.rooms .rooms
.room_aliases(&body.room_id) .room_aliases(&body.room_id)
.filter_map(|a| a.ok()) .filter_map(|a| a.ok())
@ -479,12 +465,11 @@ pub async fn get_room_aliases_route(
/// - Moves local aliases /// - Moves local aliases
/// - Modifies old room power levels to prevent users from speaking /// - Modifies old room power levels to prevent users from speaking
pub async fn upgrade_room_route( pub async fn upgrade_room_route(
db: DatabaseGuard,
body: Ruma<upgrade_room::v3::IncomingRequest>, body: Ruma<upgrade_room::v3::IncomingRequest>,
) -> Result<upgrade_room::v3::Response> { ) -> Result<upgrade_room::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !db.rooms.is_supported_version(&db, &body.new_version) { if !services().rooms.is_supported_version(&body.new_version) {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnsupportedRoomVersion, ErrorKind::UnsupportedRoomVersion,
"This server does not support that room version.", "This server does not support that room version.",
@ -492,12 +477,12 @@ pub async fn upgrade_room_route(
} }
// Create a replacement room // Create a replacement room
let replacement_room = RoomId::new(db.globals.server_name()); let replacement_room = RoomId::new(services().globals.server_name());
db.rooms services().rooms
.get_or_create_shortroomid(&replacement_room, &db.globals)?; .get_or_create_shortroomid(&replacement_room)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -508,7 +493,7 @@ pub async fn upgrade_room_route(
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further
// Fail if the sender does not have the required permissions // Fail if the sender does not have the required permissions
let tombstone_event_id = db.rooms.build_and_append_pdu( let tombstone_event_id = services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomTombstone, event_type: RoomEventType::RoomTombstone,
content: to_raw_value(&RoomTombstoneEventContent { content: to_raw_value(&RoomTombstoneEventContent {
@ -522,14 +507,13 @@ pub async fn upgrade_room_route(
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// Change lock to replacement room // Change lock to replacement room
drop(state_lock); drop(state_lock);
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -540,7 +524,7 @@ pub async fn upgrade_room_route(
// Get the old room creation event // Get the old room creation event
let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>( let mut create_event_content = serde_json::from_str::<CanonicalJsonObject>(
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.content .content
@ -588,7 +572,7 @@ pub async fn upgrade_room_route(
)); ));
} }
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomCreate, event_type: RoomEventType::RoomCreate,
content: to_raw_value(&create_event_content) content: to_raw_value(&create_event_content)
@ -599,21 +583,20 @@ pub async fn upgrade_room_route(
}, },
sender_user, sender_user,
&replacement_room, &replacement_room,
&db,
&state_lock, &state_lock,
)?; )?;
// Join the new room // Join the new room
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMember, event_type: RoomEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
membership: MembershipState::Join, membership: MembershipState::Join,
displayname: db.users.displayname(sender_user)?, displayname: services().users.displayname(sender_user)?,
avatar_url: db.users.avatar_url(sender_user)?, avatar_url: services().users.avatar_url(sender_user)?,
is_direct: None, is_direct: None,
third_party_invite: None, third_party_invite: None,
blurhash: db.users.blurhash(sender_user)?, blurhash: services().users.blurhash(sender_user)?,
reason: None, reason: None,
join_authorized_via_users_server: None, join_authorized_via_users_server: None,
}) })
@ -624,7 +607,6 @@ pub async fn upgrade_room_route(
}, },
sender_user, sender_user,
&replacement_room, &replacement_room,
&db,
&state_lock, &state_lock,
)?; )?;
@ -643,12 +625,12 @@ pub async fn upgrade_room_route(
// Replicate transferable state events to the new room // Replicate transferable state events to the new room
for event_type in transferable_state_events { for event_type in transferable_state_events {
let event_content = match db.rooms.room_state_get(&body.room_id, &event_type, "")? { let event_content = match services().rooms.room_state_get(&body.room_id, &event_type, "")? {
Some(v) => v.content.clone(), Some(v) => v.content.clone(),
None => continue, // Skipping missing events. None => continue, // Skipping missing events.
}; };
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: event_type.to_string().into(), event_type: event_type.to_string().into(),
content: event_content, content: event_content,
@ -658,20 +640,19 @@ pub async fn upgrade_room_route(
}, },
sender_user, sender_user,
&replacement_room, &replacement_room,
&db,
&state_lock, &state_lock,
)?; )?;
} }
// Moves any local aliases to the new room // Moves any local aliases to the new room
for alias in db.rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) { for alias in services().rooms.room_aliases(&body.room_id).filter_map(|r| r.ok()) {
db.rooms services().rooms
.set_alias(&alias, Some(&replacement_room), &db.globals)?; .set_alias(&alias, Some(&replacement_room))?;
} }
// Get the old room power levels // Get the old room power levels
let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str(
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")?
.ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))?
.content .content
@ -685,7 +666,7 @@ pub async fn upgrade_room_route(
power_levels_event_content.invite = new_level; power_levels_event_content.invite = new_level;
// Modify the power levels in the old room to prevent sending of events and inviting new users // Modify the power levels in the old room to prevent sending of events and inviting new users
let _ = db.rooms.build_and_append_pdu( let _ = services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomPowerLevels, event_type: RoomEventType::RoomPowerLevels,
content: to_raw_value(&power_levels_event_content) content: to_raw_value(&power_levels_event_content)
@ -696,35 +677,12 @@ pub async fn upgrade_room_route(
}, },
sender_user, sender_user,
&body.room_id, &body.room_id,
&db,
&state_lock, &state_lock,
)?; )?;
drop(state_lock); drop(state_lock);
db.flush()?;
// Return the replacement room id // Return the replacement room id
Ok(upgrade_room::v3::Response { replacement_room }) Ok(upgrade_room::v3::Response { replacement_room })
} }
/// Returns the room's version.
#[tracing::instrument(skip(self))]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = self.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: Option<RoomCreateEventContent> = create_event
.as_ref()
.map(|create_event| {
serde_json::from_str(create_event.content.get()).map_err(|e| {
warn!("Invalid create event: {}", e);
Error::bad_database("Invalid create event in db.")
})
})
.transpose()?;
let room_version = create_event_content
.map(|create_event| create_event.room_version)
.ok_or_else(|| Error::BadDatabase("Invalid room version"))?;
Ok(room_version)
}

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::api::client::{ use ruma::api::client::{
error::ErrorKind, error::ErrorKind,
search::search_events::{ search::search_events::{
@ -15,7 +15,6 @@ use std::collections::BTreeMap;
/// ///
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility) /// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
pub async fn search_events_route( pub async fn search_events_route(
db: DatabaseGuard,
body: Ruma<search_events::v3::IncomingRequest>, body: Ruma<search_events::v3::IncomingRequest>,
) -> Result<search_events::v3::Response> { ) -> Result<search_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -24,7 +23,7 @@ pub async fn search_events_route(
let filter = &search_criteria.filter; let filter = &search_criteria.filter;
let room_ids = filter.rooms.clone().unwrap_or_else(|| { let room_ids = filter.rooms.clone().unwrap_or_else(|| {
db.rooms services().rooms
.rooms_joined(sender_user) .rooms_joined(sender_user)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.collect() .collect()
@ -35,14 +34,14 @@ pub async fn search_events_route(
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in room_ids { for room_id in room_ids {
if !db.rooms.is_joined(sender_user, &room_id)? { if !services().rooms.is_joined(sender_user, &room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
if let Some(search) = db if let Some(search) = services()
.rooms .rooms
.search_pdus(&room_id, &search_criteria.search_term)? .search_pdus(&room_id, &search_criteria.search_term)?
{ {
@ -85,7 +84,7 @@ pub async fn search_events_route(
start: None, start: None,
}, },
rank: None, rank: None,
result: db result: services()
.rooms .rooms
.get_pdu_from_id(result)? .get_pdu_from_id(result)?
.map(|pdu| pdu.to_room_event()), .map(|pdu| pdu.to_room_event()),

View file

@ -1,5 +1,5 @@
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::client::{ api::client::{
error::ErrorKind, error::ErrorKind,
@ -41,7 +41,6 @@ pub async fn get_login_types_route(
/// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
/// supported login types. /// supported login types.
pub async fn login_route( pub async fn login_route(
db: DatabaseGuard,
body: Ruma<login::v3::IncomingRequest>, body: Ruma<login::v3::IncomingRequest>,
) -> Result<login::v3::Response> { ) -> Result<login::v3::Response> {
// Validate login method // Validate login method
@ -57,11 +56,11 @@ pub async fn login_route(
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
}; };
let user_id = let user_id =
UserId::parse_with_server_name(username.to_owned(), db.globals.server_name()) UserId::parse_with_server_name(username.to_owned(), services().globals.server_name())
.map_err(|_| { .map_err(|_| {
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
})?; })?;
let hash = db.users.password_hash(&user_id)?.ok_or(Error::BadRequest( let hash = services().users.password_hash(&user_id)?.ok_or(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Wrong username or password.", "Wrong username or password.",
))?; ))?;
@ -85,7 +84,7 @@ pub async fn login_route(
user_id user_id
} }
login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => { login::v3::IncomingLoginInfo::Token(login::v3::IncomingToken { token }) => {
if let Some(jwt_decoding_key) = db.globals.jwt_decoding_key() { if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
let token = jsonwebtoken::decode::<Claims>( let token = jsonwebtoken::decode::<Claims>(
token, token,
jwt_decoding_key, jwt_decoding_key,
@ -93,7 +92,7 @@ pub async fn login_route(
) )
.map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?; .map_err(|_| Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid."))?;
let username = token.claims.sub; let username = token.claims.sub;
UserId::parse_with_server_name(username, db.globals.server_name()).map_err( UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."), |_| Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."),
)? )?
} else { } else {
@ -122,15 +121,15 @@ pub async fn login_route(
// Determine if device_id was provided and exists in the db for this user // Determine if device_id was provided and exists in the db for this user
let device_exists = body.device_id.as_ref().map_or(false, |device_id| { let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
db.users services().users
.all_device_ids(&user_id) .all_device_ids(&user_id)
.any(|x| x.as_ref().map_or(false, |v| v == device_id)) .any(|x| x.as_ref().map_or(false, |v| v == device_id))
}); });
if device_exists { if device_exists {
db.users.set_token(&user_id, &device_id, &token)?; services().users.set_token(&user_id, &device_id, &token)?;
} else { } else {
db.users.create_device( services().users.create_device(
&user_id, &user_id,
&device_id, &device_id,
&token, &token,
@ -140,12 +139,10 @@ pub async fn login_route(
info!("{} logged in", user_id); info!("{} logged in", user_id);
db.flush()?;
Ok(login::v3::Response { Ok(login::v3::Response {
user_id, user_id,
access_token: token, access_token: token,
home_server: Some(db.globals.server_name().to_owned()), home_server: Some(services().globals.server_name().to_owned()),
device_id, device_id,
well_known: None, well_known: None,
}) })
@ -160,15 +157,12 @@ pub async fn login_route(
/// - Forgets to-device events /// - Forgets to-device events
/// - Triggers device list updates /// - Triggers device list updates
pub async fn logout_route( pub async fn logout_route(
db: DatabaseGuard,
body: Ruma<logout::v3::Request>, body: Ruma<logout::v3::Request>,
) -> Result<logout::v3::Response> { ) -> Result<logout::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_ref().expect("user is authenticated"); let sender_device = body.sender_device.as_ref().expect("user is authenticated");
db.users.remove_device(sender_user, sender_device)?; services().users.remove_device(sender_user, sender_device)?;
db.flush()?;
Ok(logout::v3::Response::new()) Ok(logout::v3::Response::new())
} }
@ -185,16 +179,13 @@ pub async fn logout_route(
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html) /// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html)
/// from each device of this user. /// from each device of this user.
pub async fn logout_all_route( pub async fn logout_all_route(
db: DatabaseGuard,
body: Ruma<logout_all::v3::Request>, body: Ruma<logout_all::v3::Request>,
) -> Result<logout_all::v3::Response> { ) -> Result<logout_all::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
for device_id in db.users.all_device_ids(sender_user).flatten() { for device_id in services().users.all_device_ids(sender_user).flatten() {
db.users.remove_device(sender_user, &device_id)?; services().users.remove_device(sender_user, &device_id)?;
} }
db.flush()?;
Ok(logout_all::v3::Response::new()) Ok(logout_all::v3::Response::new())
} }

View file

@ -1,7 +1,7 @@
use std::sync::Arc; use std::sync::Arc;
use crate::{ use crate::{
database::DatabaseGuard, pdu::PduBuilder, Database, Error, Result, Ruma, RumaResponse, Error, Result, Ruma, RumaResponse, services, service::pdu::PduBuilder,
}; };
use ruma::{ use ruma::{
api::client::{ api::client::{
@ -27,13 +27,11 @@ use ruma::{
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is allowed
/// - If event is new canonical_alias: Rejects if alias is incorrect /// - If event is new canonical_alias: Rejects if alias is incorrect
pub async fn send_state_event_for_key_route( pub async fn send_state_event_for_key_route(
db: DatabaseGuard,
body: Ruma<send_state_event::v3::IncomingRequest>, body: Ruma<send_state_event::v3::IncomingRequest>,
) -> Result<send_state_event::v3::Response> { ) -> Result<send_state_event::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let event_id = send_state_event_for_key_helper( let event_id = send_state_event_for_key_helper(
&db,
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_type, &body.event_type,
@ -42,8 +40,6 @@ pub async fn send_state_event_for_key_route(
) )
.await?; .await?;
db.flush()?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }) Ok(send_state_event::v3::Response { event_id })
} }
@ -56,13 +52,12 @@ pub async fn send_state_event_for_key_route(
/// - Tries to send the event into the room, auth rules will determine if it is allowed /// - Tries to send the event into the room, auth rules will determine if it is allowed
/// - If event is new canonical_alias: Rejects if alias is incorrect /// - If event is new canonical_alias: Rejects if alias is incorrect
pub async fn send_state_event_for_empty_key_route( pub async fn send_state_event_for_empty_key_route(
db: DatabaseGuard,
body: Ruma<send_state_event::v3::IncomingRequest>, body: Ruma<send_state_event::v3::IncomingRequest>,
) -> Result<RumaResponse<send_state_event::v3::Response>> { ) -> Result<RumaResponse<send_state_event::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
// Forbid m.room.encryption if encryption is disabled // Forbid m.room.encryption if encryption is disabled
if body.event_type == StateEventType::RoomEncryption && !db.globals.allow_encryption() { if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"Encryption has been disabled", "Encryption has been disabled",
@ -70,7 +65,6 @@ pub async fn send_state_event_for_empty_key_route(
} }
let event_id = send_state_event_for_key_helper( let event_id = send_state_event_for_key_helper(
&db,
sender_user, sender_user,
&body.room_id, &body.room_id,
&body.event_type.to_string().into(), &body.event_type.to_string().into(),
@ -79,8 +73,6 @@ pub async fn send_state_event_for_empty_key_route(
) )
.await?; .await?;
db.flush()?;
let event_id = (*event_id).to_owned(); let event_id = (*event_id).to_owned();
Ok(send_state_event::v3::Response { event_id }.into()) Ok(send_state_event::v3::Response { event_id }.into())
} }
@ -91,7 +83,6 @@ pub async fn send_state_event_for_empty_key_route(
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world readable
pub async fn get_state_events_route( pub async fn get_state_events_route(
db: DatabaseGuard,
body: Ruma<get_state_events::v3::IncomingRequest>, body: Ruma<get_state_events::v3::IncomingRequest>,
) -> Result<get_state_events::v3::Response> { ) -> Result<get_state_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -99,9 +90,9 @@ pub async fn get_state_events_route(
#[allow(clippy::blocks_in_if_conditions)] #[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is // Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable // WorldReadable
if !db.rooms.is_joined(sender_user, &body.room_id)? if !services().rooms.is_joined(sender_user, &body.room_id)?
&& !matches!( && !matches!(
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| { .map(|event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
@ -122,7 +113,7 @@ pub async fn get_state_events_route(
} }
Ok(get_state_events::v3::Response { Ok(get_state_events::v3::Response {
room_state: db room_state: services()
.rooms .rooms
.room_state_full(&body.room_id) .room_state_full(&body.room_id)
.await? .await?
@ -138,7 +129,6 @@ pub async fn get_state_events_route(
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world readable
pub async fn get_state_events_for_key_route( pub async fn get_state_events_for_key_route(
db: DatabaseGuard,
body: Ruma<get_state_events_for_key::v3::IncomingRequest>, body: Ruma<get_state_events_for_key::v3::IncomingRequest>,
) -> Result<get_state_events_for_key::v3::Response> { ) -> Result<get_state_events_for_key::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -146,9 +136,9 @@ pub async fn get_state_events_for_key_route(
#[allow(clippy::blocks_in_if_conditions)] #[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is // Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable // WorldReadable
if !db.rooms.is_joined(sender_user, &body.room_id)? if !services().rooms.is_joined(sender_user, &body.room_id)?
&& !matches!( && !matches!(
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| { .map(|event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
@ -168,7 +158,7 @@ pub async fn get_state_events_for_key_route(
)); ));
} }
let event = db let event = services()
.rooms .rooms
.room_state_get(&body.room_id, &body.event_type, &body.state_key)? .room_state_get(&body.room_id, &body.event_type, &body.state_key)?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -188,7 +178,6 @@ pub async fn get_state_events_for_key_route(
/// ///
/// - If not joined: Only works if current room history visibility is world readable /// - If not joined: Only works if current room history visibility is world readable
pub async fn get_state_events_for_empty_key_route( pub async fn get_state_events_for_empty_key_route(
db: DatabaseGuard,
body: Ruma<get_state_events_for_key::v3::IncomingRequest>, body: Ruma<get_state_events_for_key::v3::IncomingRequest>,
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> { ) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@ -196,9 +185,9 @@ pub async fn get_state_events_for_empty_key_route(
#[allow(clippy::blocks_in_if_conditions)] #[allow(clippy::blocks_in_if_conditions)]
// Users not in the room should not be able to access the state unless history_visibility is // Users not in the room should not be able to access the state unless history_visibility is
// WorldReadable // WorldReadable
if !db.rooms.is_joined(sender_user, &body.room_id)? if !services().rooms.is_joined(sender_user, &body.room_id)?
&& !matches!( && !matches!(
db.rooms services().rooms
.room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")? .room_state_get(&body.room_id, &StateEventType::RoomHistoryVisibility, "")?
.map(|event| { .map(|event| {
serde_json::from_str(event.content.get()) serde_json::from_str(event.content.get())
@ -218,7 +207,7 @@ pub async fn get_state_events_for_empty_key_route(
)); ));
} }
let event = db let event = services()
.rooms .rooms
.room_state_get(&body.room_id, &body.event_type, "")? .room_state_get(&body.room_id, &body.event_type, "")?
.ok_or(Error::BadRequest( .ok_or(Error::BadRequest(
@ -234,7 +223,6 @@ pub async fn get_state_events_for_empty_key_route(
} }
async fn send_state_event_for_key_helper( async fn send_state_event_for_key_helper(
db: &Database,
sender: &UserId, sender: &UserId,
room_id: &RoomId, room_id: &RoomId,
event_type: &StateEventType, event_type: &StateEventType,
@ -255,8 +243,8 @@ async fn send_state_event_for_key_helper(
} }
for alias in aliases { for alias in aliases {
if alias.server_name() != db.globals.server_name() if alias.server_name() != services().globals.server_name()
|| db || services()
.rooms .rooms
.id_from_alias(&alias)? .id_from_alias(&alias)?
.filter(|room| room == room_id) // Make sure it's the right room .filter(|room| room == room_id) // Make sure it's the right room
@ -272,7 +260,7 @@ async fn send_state_event_for_key_helper(
} }
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -281,7 +269,7 @@ async fn send_state_event_for_key_helper(
); );
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
let event_id = db.rooms.build_and_append_pdu( let event_id = services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: event_type.to_string().into(), event_type: event_type.to_string().into(),
content: serde_json::from_str(json.json().get()).expect("content is valid json"), content: serde_json::from_str(json.json().get()).expect("content is valid json"),
@ -291,7 +279,6 @@ async fn send_state_event_for_key_helper(
}, },
sender_user, sender_user,
room_id, room_id,
db,
&state_lock, &state_lock,
)?; )?;

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Database, Error, Result, Ruma, RumaResponse}; use crate::{Error, Result, Ruma, RumaResponse, services};
use ruma::{ use ruma::{
api::client::{ api::client::{
filter::{IncomingFilterDefinition, LazyLoadOptions}, filter::{IncomingFilterDefinition, LazyLoadOptions},
@ -55,16 +55,13 @@ use tracing::error;
/// - Sync is handled in an async task, multiple requests from the same device with the same /// - Sync is handled in an async task, multiple requests from the same device with the same
/// `since` will be cached /// `since` will be cached
pub async fn sync_events_route( pub async fn sync_events_route(
db: DatabaseGuard,
body: Ruma<sync_events::v3::IncomingRequest>, body: Ruma<sync_events::v3::IncomingRequest>,
) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> { ) -> Result<sync_events::v3::Response, RumaResponse<UiaaResponse>> {
let sender_user = body.sender_user.expect("user is authenticated"); let sender_user = body.sender_user.expect("user is authenticated");
let sender_device = body.sender_device.expect("user is authenticated"); let sender_device = body.sender_device.expect("user is authenticated");
let body = body.body; let body = body.body;
let arc_db = Arc::new(db); let mut rx = match services()
let mut rx = match arc_db
.globals .globals
.sync_receivers .sync_receivers
.write() .write()
@ -77,7 +74,6 @@ pub async fn sync_events_route(
v.insert((body.since.to_owned(), rx.clone())); v.insert((body.since.to_owned(), rx.clone()));
tokio::spawn(sync_helper_wrapper( tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
body, body,
@ -93,7 +89,6 @@ pub async fn sync_events_route(
o.insert((body.since.clone(), rx.clone())); o.insert((body.since.clone(), rx.clone()));
tokio::spawn(sync_helper_wrapper( tokio::spawn(sync_helper_wrapper(
Arc::clone(&arc_db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
body, body,
@ -127,7 +122,6 @@ pub async fn sync_events_route(
} }
async fn sync_helper_wrapper( async fn sync_helper_wrapper(
db: Arc<DatabaseGuard>,
sender_user: Box<UserId>, sender_user: Box<UserId>,
sender_device: Box<DeviceId>, sender_device: Box<DeviceId>,
body: sync_events::v3::IncomingRequest, body: sync_events::v3::IncomingRequest,
@ -136,7 +130,6 @@ async fn sync_helper_wrapper(
let since = body.since.clone(); let since = body.since.clone();
let r = sync_helper( let r = sync_helper(
Arc::clone(&db),
sender_user.clone(), sender_user.clone(),
sender_device.clone(), sender_device.clone(),
body, body,
@ -145,7 +138,7 @@ async fn sync_helper_wrapper(
if let Ok((_, caching_allowed)) = r { if let Ok((_, caching_allowed)) = r {
if !caching_allowed { if !caching_allowed {
match db match services()
.globals .globals
.sync_receivers .sync_receivers
.write() .write()
@ -163,13 +156,10 @@ async fn sync_helper_wrapper(
} }
} }
drop(db);
let _ = tx.send(Some(r.map(|(r, _)| r))); let _ = tx.send(Some(r.map(|(r, _)| r)));
} }
async fn sync_helper( async fn sync_helper(
db: Arc<DatabaseGuard>,
sender_user: Box<UserId>, sender_user: Box<UserId>,
sender_device: Box<DeviceId>, sender_device: Box<DeviceId>,
body: sync_events::v3::IncomingRequest, body: sync_events::v3::IncomingRequest,
@ -182,19 +172,19 @@ async fn sync_helper(
}; };
// TODO: match body.set_presence { // TODO: match body.set_presence {
db.rooms.edus.ping_presence(&sender_user)?; services().rooms.edus.ping_presence(&sender_user)?;
// Setup watchers, so if there's no response, we can wait for them // Setup watchers, so if there's no response, we can wait for them
let watcher = db.watch(&sender_user, &sender_device); let watcher = services().watch(&sender_user, &sender_device);
let next_batch = db.globals.current_count()?; let next_batch = services().globals.current_count()?;
let next_batch_string = next_batch.to_string(); let next_batch_string = next_batch.to_string();
// Load filter // Load filter
let filter = match body.filter { let filter = match body.filter {
None => IncomingFilterDefinition::default(), None => IncomingFilterDefinition::default(),
Some(IncomingFilter::FilterDefinition(filter)) => filter, Some(IncomingFilter::FilterDefinition(filter)) => filter,
Some(IncomingFilter::FilterId(filter_id)) => db Some(IncomingFilter::FilterId(filter_id)) => services()
.users .users
.get_filter(&sender_user, &filter_id)? .get_filter(&sender_user, &filter_id)?
.unwrap_or_default(), .unwrap_or_default(),
@ -221,12 +211,12 @@ async fn sync_helper(
// Look for device list updates of this account // Look for device list updates of this account
device_list_updates.extend( device_list_updates.extend(
db.users services().users
.keys_changed(&sender_user.to_string(), since, None) .keys_changed(&sender_user.to_string(), since, None)
.filter_map(|r| r.ok()), .filter_map(|r| r.ok()),
); );
let all_joined_rooms = db.rooms.rooms_joined(&sender_user).collect::<Vec<_>>(); let all_joined_rooms = services().rooms.rooms_joined(&sender_user).collect::<Vec<_>>();
for room_id in all_joined_rooms { for room_id in all_joined_rooms {
let room_id = room_id?; let room_id = room_id?;
@ -234,7 +224,7 @@ async fn sync_helper(
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
// This will make sure the we have all events until next_batch // This will make sure the we have all events until next_batch
let mutex_insert = Arc::clone( let mutex_insert = Arc::clone(
db.globals services().globals
.roomid_mutex_insert .roomid_mutex_insert
.write() .write()
.unwrap() .unwrap()
@ -247,8 +237,8 @@ async fn sync_helper(
let timeline_pdus; let timeline_pdus;
let limited; let limited;
if db.rooms.last_timeline_count(&sender_user, &room_id)? > since { if services().rooms.last_timeline_count(&sender_user, &room_id)? > since {
let mut non_timeline_pdus = db let mut non_timeline_pdus = services()
.rooms .rooms
.pdus_until(&sender_user, &room_id, u64::MAX)? .pdus_until(&sender_user, &room_id, u64::MAX)?
.filter_map(|r| { .filter_map(|r| {
@ -259,7 +249,7 @@ async fn sync_helper(
r.ok() r.ok()
}) })
.take_while(|(pduid, _)| { .take_while(|(pduid, _)| {
db.rooms services().rooms
.pdu_count(pduid) .pdu_count(pduid)
.map_or(false, |count| count > since) .map_or(false, |count| count > since)
}); });
@ -282,7 +272,7 @@ async fn sync_helper(
} }
let send_notification_counts = !timeline_pdus.is_empty() let send_notification_counts = !timeline_pdus.is_empty()
|| db || services()
.rooms .rooms
.edus .edus
.last_privateread_update(&sender_user, &room_id)? .last_privateread_update(&sender_user, &room_id)?
@ -293,24 +283,24 @@ async fn sync_helper(
timeline_users.insert(event.sender.as_str().to_owned()); timeline_users.insert(event.sender.as_str().to_owned());
} }
db.rooms services().rooms
.lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?; .lazy_load_confirm_delivery(&sender_user, &sender_device, &room_id, since)?;
// Database queries: // Database queries:
let current_shortstatehash = if let Some(s) = db.rooms.current_shortstatehash(&room_id)? { let current_shortstatehash = if let Some(s) = services().rooms.current_shortstatehash(&room_id)? {
s s
} else { } else {
error!("Room {} has no state", room_id); error!("Room {} has no state", room_id);
continue; continue;
}; };
let since_shortstatehash = db.rooms.get_token_shortstatehash(&room_id, since)?; let since_shortstatehash = services().rooms.get_token_shortstatehash(&room_id, since)?;
// Calculates joined_member_count, invited_member_count and heroes // Calculates joined_member_count, invited_member_count and heroes
let calculate_counts = || { let calculate_counts = || {
let joined_member_count = db.rooms.room_joined_count(&room_id)?.unwrap_or(0); let joined_member_count = services().rooms.room_joined_count(&room_id)?.unwrap_or(0);
let invited_member_count = db.rooms.room_invited_count(&room_id)?.unwrap_or(0); let invited_member_count = services().rooms.room_invited_count(&room_id)?.unwrap_or(0);
// Recalculate heroes (first 5 members) // Recalculate heroes (first 5 members)
let mut heroes = Vec::new(); let mut heroes = Vec::new();
@ -319,7 +309,7 @@ async fn sync_helper(
// Go through all PDUs and for each member event, check if the user is still joined or // Go through all PDUs and for each member event, check if the user is still joined or
// invited until we have 5 or we reach the end // invited until we have 5 or we reach the end
for hero in db for hero in services()
.rooms .rooms
.all_pdus(&sender_user, &room_id)? .all_pdus(&sender_user, &room_id)?
.filter_map(|pdu| pdu.ok()) // Ignore all broken pdus .filter_map(|pdu| pdu.ok()) // Ignore all broken pdus
@ -339,8 +329,8 @@ async fn sync_helper(
if matches!( if matches!(
content.membership, content.membership,
MembershipState::Join | MembershipState::Invite MembershipState::Join | MembershipState::Invite
) && (db.rooms.is_joined(&user_id, &room_id)? ) && (services().rooms.is_joined(&user_id, &room_id)?
|| db.rooms.is_invited(&user_id, &room_id)?) || services().rooms.is_invited(&user_id, &room_id)?)
{ {
Ok::<_, Error>(Some(state_key.clone())) Ok::<_, Error>(Some(state_key.clone()))
} else { } else {
@ -381,17 +371,17 @@ async fn sync_helper(
let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; let (joined_member_count, invited_member_count, heroes) = calculate_counts()?;
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?;
let mut state_events = Vec::new(); let mut state_events = Vec::new();
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
let mut i = 0; let mut i = 0;
for (shortstatekey, id) in current_state_ids { for (shortstatekey, id) in current_state_ids {
let (event_type, state_key) = db.rooms.get_statekey_from_short(shortstatekey)?; let (event_type, state_key) = services().rooms.get_statekey_from_short(shortstatekey)?;
if event_type != StateEventType::RoomMember { if event_type != StateEventType::RoomMember {
let pdu = match db.rooms.get_pdu(&id)? { let pdu = match services().rooms.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
@ -408,7 +398,7 @@ async fn sync_helper(
|| body.full_state || body.full_state
|| timeline_users.contains(&state_key) || timeline_users.contains(&state_key)
{ {
let pdu = match db.rooms.get_pdu(&id)? { let pdu = match services().rooms.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
@ -430,12 +420,12 @@ async fn sync_helper(
} }
// Reset lazy loading because this is an initial sync // Reset lazy loading because this is an initial sync
db.rooms services().rooms
.lazy_load_reset(&sender_user, &sender_device, &room_id)?; .lazy_load_reset(&sender_user, &sender_device, &room_id)?;
// The state_events above should contain all timeline_users, let's mark them as lazy // The state_events above should contain all timeline_users, let's mark them as lazy
// loaded. // loaded.
db.rooms.lazy_load_mark_sent( services().rooms.lazy_load_mark_sent(
&sender_user, &sender_user,
&sender_device, &sender_device,
&room_id, &room_id,
@ -457,7 +447,7 @@ async fn sync_helper(
// Incremental /sync // Incremental /sync
let since_shortstatehash = since_shortstatehash.unwrap(); let since_shortstatehash = since_shortstatehash.unwrap();
let since_sender_member: Option<RoomMemberEventContent> = db let since_sender_member: Option<RoomMemberEventContent> = services()
.rooms .rooms
.state_get( .state_get(
since_shortstatehash, since_shortstatehash,
@ -477,12 +467,12 @@ async fn sync_helper(
let mut lazy_loaded = HashSet::new(); let mut lazy_loaded = HashSet::new();
if since_shortstatehash != current_shortstatehash { if since_shortstatehash != current_shortstatehash {
let current_state_ids = db.rooms.state_full_ids(current_shortstatehash).await?; let current_state_ids = services().rooms.state_full_ids(current_shortstatehash).await?;
let since_state_ids = db.rooms.state_full_ids(since_shortstatehash).await?; let since_state_ids = services().rooms.state_full_ids(since_shortstatehash).await?;
for (key, id) in current_state_ids { for (key, id) in current_state_ids {
if body.full_state || since_state_ids.get(&key) != Some(&id) { if body.full_state || since_state_ids.get(&key) != Some(&id) {
let pdu = match db.rooms.get_pdu(&id)? { let pdu = match services().rooms.get_pdu(&id)? {
Some(pdu) => pdu, Some(pdu) => pdu,
None => { None => {
error!("Pdu in state not found: {}", id); error!("Pdu in state not found: {}", id);
@ -515,14 +505,14 @@ async fn sync_helper(
continue; continue;
} }
if !db.rooms.lazy_load_was_sent_before( if !services().rooms.lazy_load_was_sent_before(
&sender_user, &sender_user,
&sender_device, &sender_device,
&room_id, &room_id,
&event.sender, &event.sender,
)? || lazy_load_send_redundant )? || lazy_load_send_redundant
{ {
if let Some(member_event) = db.rooms.room_state_get( if let Some(member_event) = services().rooms.room_state_get(
&room_id, &room_id,
&StateEventType::RoomMember, &StateEventType::RoomMember,
event.sender.as_str(), event.sender.as_str(),
@ -533,7 +523,7 @@ async fn sync_helper(
} }
} }
db.rooms.lazy_load_mark_sent( services().rooms.lazy_load_mark_sent(
&sender_user, &sender_user,
&sender_device, &sender_device,
&room_id, &room_id,
@ -541,13 +531,13 @@ async fn sync_helper(
next_batch, next_batch,
); );
let encrypted_room = db let encrypted_room = services()
.rooms .rooms
.state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")?
.is_some(); .is_some();
let since_encryption = let since_encryption =
db.rooms services().rooms
.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; .state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?;
// Calculations: // Calculations:
@ -580,7 +570,7 @@ async fn sync_helper(
match new_membership { match new_membership {
MembershipState::Join => { MembershipState::Join => {
// A new user joined an encrypted room // A new user joined an encrypted room
if !share_encrypted_room(&db, &sender_user, &user_id, &room_id)? { if !share_encrypted_room(&sender_user, &user_id, &room_id)? {
device_list_updates.insert(user_id); device_list_updates.insert(user_id);
} }
} }
@ -597,7 +587,7 @@ async fn sync_helper(
if joined_since_last_sync && encrypted_room || new_encrypted_room { if joined_since_last_sync && encrypted_room || new_encrypted_room {
// If the user is in a new encrypted room, give them all joined users // If the user is in a new encrypted room, give them all joined users
device_list_updates.extend( device_list_updates.extend(
db.rooms services().rooms
.room_members(&room_id) .room_members(&room_id)
.flatten() .flatten()
.filter(|user_id| { .filter(|user_id| {
@ -606,7 +596,7 @@ async fn sync_helper(
}) })
.filter(|user_id| { .filter(|user_id| {
// Only send keys if the sender doesn't share an encrypted room with the target already // Only send keys if the sender doesn't share an encrypted room with the target already
!share_encrypted_room(&db, &sender_user, user_id, &room_id) !share_encrypted_room(&sender_user, user_id, &room_id)
.unwrap_or(false) .unwrap_or(false)
}), }),
); );
@ -629,14 +619,14 @@ async fn sync_helper(
// Look for device list updates in this room // Look for device list updates in this room
device_list_updates.extend( device_list_updates.extend(
db.users services().users
.keys_changed(&room_id.to_string(), since, None) .keys_changed(&room_id.to_string(), since, None)
.filter_map(|r| r.ok()), .filter_map(|r| r.ok()),
); );
let notification_count = if send_notification_counts { let notification_count = if send_notification_counts {
Some( Some(
db.rooms services().rooms
.notification_count(&sender_user, &room_id)? .notification_count(&sender_user, &room_id)?
.try_into() .try_into()
.expect("notification count can't go that high"), .expect("notification count can't go that high"),
@ -647,7 +637,7 @@ async fn sync_helper(
let highlight_count = if send_notification_counts { let highlight_count = if send_notification_counts {
Some( Some(
db.rooms services().rooms
.highlight_count(&sender_user, &room_id)? .highlight_count(&sender_user, &room_id)?
.try_into() .try_into()
.expect("highlight count can't go that high"), .expect("highlight count can't go that high"),
@ -659,7 +649,7 @@ async fn sync_helper(
let prev_batch = timeline_pdus let prev_batch = timeline_pdus
.first() .first()
.map_or(Ok::<_, Error>(None), |(pdu_id, _)| { .map_or(Ok::<_, Error>(None), |(pdu_id, _)| {
Ok(Some(db.rooms.pdu_count(pdu_id)?.to_string())) Ok(Some(services().rooms.pdu_count(pdu_id)?.to_string()))
})?; })?;
let room_events: Vec<_> = timeline_pdus let room_events: Vec<_> = timeline_pdus
@ -667,7 +657,7 @@ async fn sync_helper(
.map(|(_, pdu)| pdu.to_sync_room_event()) .map(|(_, pdu)| pdu.to_sync_room_event())
.collect(); .collect();
let mut edus: Vec<_> = db let mut edus: Vec<_> = services()
.rooms .rooms
.edus .edus
.readreceipts_since(&room_id, since) .readreceipts_since(&room_id, since)
@ -675,10 +665,10 @@ async fn sync_helper(
.map(|(_, _, v)| v) .map(|(_, _, v)| v)
.collect(); .collect();
if db.rooms.edus.last_typing_update(&room_id, &db.globals)? > since { if services().rooms.edus.last_typing_update(&room_id, &services().globals)? > since {
edus.push( edus.push(
serde_json::from_str( serde_json::from_str(
&serde_json::to_string(&db.rooms.edus.typings_all(&room_id)?) &serde_json::to_string(&services().rooms.edus.typings_all(&room_id)?)
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
) )
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
@ -686,12 +676,12 @@ async fn sync_helper(
} }
// Save the state after this sync so we can send the correct state diff next sync // Save the state after this sync so we can send the correct state diff next sync
db.rooms services().rooms
.associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?; .associate_token_shortstatehash(&room_id, next_batch, current_shortstatehash)?;
let joined_room = JoinedRoom { let joined_room = JoinedRoom {
account_data: RoomAccountData { account_data: RoomAccountData {
events: db events: services()
.account_data .account_data
.changes_since(Some(&room_id), &sender_user, since)? .changes_since(Some(&room_id), &sender_user, since)?
.into_iter() .into_iter()
@ -731,9 +721,9 @@ async fn sync_helper(
// Take presence updates from this room // Take presence updates from this room
for (user_id, presence) in for (user_id, presence) in
db.rooms services().rooms
.edus .edus
.presence_since(&room_id, since, &db.rooms, &db.globals)? .presence_since(&room_id, since)?
{ {
match presence_updates.entry(user_id) { match presence_updates.entry(user_id) {
Entry::Vacant(v) => { Entry::Vacant(v) => {
@ -765,14 +755,14 @@ async fn sync_helper(
} }
let mut left_rooms = BTreeMap::new(); let mut left_rooms = BTreeMap::new();
let all_left_rooms: Vec<_> = db.rooms.rooms_left(&sender_user).collect(); let all_left_rooms: Vec<_> = services().rooms.rooms_left(&sender_user).collect();
for result in all_left_rooms { for result in all_left_rooms {
let (room_id, left_state_events) = result?; let (room_id, left_state_events) = result?;
{ {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let mutex_insert = Arc::clone( let mutex_insert = Arc::clone(
db.globals services().globals
.roomid_mutex_insert .roomid_mutex_insert
.write() .write()
.unwrap() .unwrap()
@ -783,7 +773,7 @@ async fn sync_helper(
drop(insert_lock); drop(insert_lock);
} }
let left_count = db.rooms.get_left_count(&room_id, &sender_user)?; let left_count = services().rooms.get_left_count(&room_id, &sender_user)?;
// Left before last sync // Left before last sync
if Some(since) >= left_count { if Some(since) >= left_count {
@ -807,14 +797,14 @@ async fn sync_helper(
} }
let mut invited_rooms = BTreeMap::new(); let mut invited_rooms = BTreeMap::new();
let all_invited_rooms: Vec<_> = db.rooms.rooms_invited(&sender_user).collect(); let all_invited_rooms: Vec<_> = services().rooms.rooms_invited(&sender_user).collect();
for result in all_invited_rooms { for result in all_invited_rooms {
let (room_id, invite_state_events) = result?; let (room_id, invite_state_events) = result?;
{ {
// Get and drop the lock to wait for remaining operations to finish // Get and drop the lock to wait for remaining operations to finish
let mutex_insert = Arc::clone( let mutex_insert = Arc::clone(
db.globals services().globals
.roomid_mutex_insert .roomid_mutex_insert
.write() .write()
.unwrap() .unwrap()
@ -825,7 +815,7 @@ async fn sync_helper(
drop(insert_lock); drop(insert_lock);
} }
let invite_count = db.rooms.get_invite_count(&room_id, &sender_user)?; let invite_count = services().rooms.get_invite_count(&room_id, &sender_user)?;
// Invited before last sync // Invited before last sync
if Some(since) >= invite_count { if Some(since) >= invite_count {
@ -843,13 +833,13 @@ async fn sync_helper(
} }
for user_id in left_encrypted_users { for user_id in left_encrypted_users {
let still_share_encrypted_room = db let still_share_encrypted_room = services()
.rooms .rooms
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
db.rooms services().rooms
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
.ok()? .ok()?
.is_some(), .is_some(),
@ -864,7 +854,7 @@ async fn sync_helper(
} }
// Remove all to-device events the device received *last time* // Remove all to-device events the device received *last time*
db.users services().users
.remove_to_device_events(&sender_user, &sender_device, since)?; .remove_to_device_events(&sender_user, &sender_device, since)?;
let response = sync_events::v3::Response { let response = sync_events::v3::Response {
@ -882,7 +872,7 @@ async fn sync_helper(
.collect(), .collect(),
}, },
account_data: GlobalAccountData { account_data: GlobalAccountData {
events: db events: services()
.account_data .account_data
.changes_since(None, &sender_user, since)? .changes_since(None, &sender_user, since)?
.into_iter() .into_iter()
@ -897,9 +887,9 @@ async fn sync_helper(
changed: device_list_updates.into_iter().collect(), changed: device_list_updates.into_iter().collect(),
left: device_list_left.into_iter().collect(), left: device_list_left.into_iter().collect(),
}, },
device_one_time_keys_count: db.users.count_one_time_keys(&sender_user, &sender_device)?, device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?,
to_device: ToDevice { to_device: ToDevice {
events: db events: services()
.users .users
.get_to_device_events(&sender_user, &sender_device)?, .get_to_device_events(&sender_user, &sender_device)?,
}, },
@ -928,21 +918,19 @@ async fn sync_helper(
} }
} }
#[tracing::instrument(skip(db))]
fn share_encrypted_room( fn share_encrypted_room(
db: &Database,
sender_user: &UserId, sender_user: &UserId,
user_id: &UserId, user_id: &UserId,
ignore_room: &RoomId, ignore_room: &RoomId,
) -> Result<bool> { ) -> Result<bool> {
Ok(db Ok(services()
.rooms .rooms
.get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])?
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.filter(|room_id| room_id != ignore_room) .filter(|room_id| room_id != ignore_room)
.filter_map(|other_room_id| { .filter_map(|other_room_id| {
Some( Some(
db.rooms services().rooms
.room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "")
.ok()? .ok()?
.is_some(), .is_some(),

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Result, Ruma}; use crate::{Result, Ruma, services};
use ruma::{ use ruma::{
api::client::tag::{create_tag, delete_tag, get_tags}, api::client::tag::{create_tag, delete_tag, get_tags},
events::{ events::{
@ -14,12 +14,11 @@ use std::collections::BTreeMap;
/// ///
/// - Inserts the tag into the tag event of the room account data. /// - Inserts the tag into the tag event of the room account data.
pub async fn update_tag_route( pub async fn update_tag_route(
db: DatabaseGuard,
body: Ruma<create_tag::v3::IncomingRequest>, body: Ruma<create_tag::v3::IncomingRequest>,
) -> Result<create_tag::v3::Response> { ) -> Result<create_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut tags_event = db let mut tags_event = services()
.account_data .account_data
.get( .get(
Some(&body.room_id), Some(&body.room_id),
@ -36,16 +35,13 @@ pub async fn update_tag_route(
.tags .tags
.insert(body.tag.clone().into(), body.tag_info.clone()); .insert(body.tag.clone().into(), body.tag_info.clone());
db.account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&tags_event, &tags_event,
&db.globals,
)?; )?;
db.flush()?;
Ok(create_tag::v3::Response {}) Ok(create_tag::v3::Response {})
} }
@ -55,12 +51,11 @@ pub async fn update_tag_route(
/// ///
/// - Removes the tag from the tag event of the room account data. /// - Removes the tag from the tag event of the room account data.
pub async fn delete_tag_route( pub async fn delete_tag_route(
db: DatabaseGuard,
body: Ruma<delete_tag::v3::IncomingRequest>, body: Ruma<delete_tag::v3::IncomingRequest>,
) -> Result<delete_tag::v3::Response> { ) -> Result<delete_tag::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let mut tags_event = db let mut tags_event = services()
.account_data .account_data
.get( .get(
Some(&body.room_id), Some(&body.room_id),
@ -74,16 +69,13 @@ pub async fn delete_tag_route(
}); });
tags_event.content.tags.remove(&body.tag.clone().into()); tags_event.content.tags.remove(&body.tag.clone().into());
db.account_data.update( services().account_data.update(
Some(&body.room_id), Some(&body.room_id),
sender_user, sender_user,
RoomAccountDataEventType::Tag, RoomAccountDataEventType::Tag,
&tags_event, &tags_event,
&db.globals,
)?; )?;
db.flush()?;
Ok(delete_tag::v3::Response {}) Ok(delete_tag::v3::Response {})
} }
@ -93,13 +85,12 @@ pub async fn delete_tag_route(
/// ///
/// - Gets the tag event of the room account data. /// - Gets the tag event of the room account data.
pub async fn get_tags_route( pub async fn get_tags_route(
db: DatabaseGuard,
body: Ruma<get_tags::v3::IncomingRequest>, body: Ruma<get_tags::v3::IncomingRequest>,
) -> Result<get_tags::v3::Response> { ) -> Result<get_tags::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(get_tags::v3::Response { Ok(get_tags::v3::Response {
tags: db tags: services()
.account_data .account_data
.get( .get(
Some(&body.room_id), Some(&body.room_id),

View file

@ -1,7 +1,7 @@
use ruma::events::ToDeviceEventType; use ruma::events::ToDeviceEventType;
use std::collections::BTreeMap; use std::collections::BTreeMap;
use crate::{database::DatabaseGuard, Error, Result, Ruma}; use crate::{Error, Result, Ruma, services};
use ruma::{ use ruma::{
api::{ api::{
client::{error::ErrorKind, to_device::send_event_to_device}, client::{error::ErrorKind, to_device::send_event_to_device},
@ -14,14 +14,13 @@ use ruma::{
/// ///
/// Send a to-device event to a set of client devices. /// Send a to-device event to a set of client devices.
pub async fn send_event_to_device_route( pub async fn send_event_to_device_route(
db: DatabaseGuard,
body: Ruma<send_event_to_device::v3::IncomingRequest>, body: Ruma<send_event_to_device::v3::IncomingRequest>,
) -> Result<send_event_to_device::v3::Response> { ) -> Result<send_event_to_device::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let sender_device = body.sender_device.as_deref(); let sender_device = body.sender_device.as_deref();
// Check if this is a new transaction id // Check if this is a new transaction id
if db if services()
.transaction_ids .transaction_ids
.existing_txnid(sender_user, sender_device, &body.txn_id)? .existing_txnid(sender_user, sender_device, &body.txn_id)?
.is_some() .is_some()
@ -31,13 +30,13 @@ pub async fn send_event_to_device_route(
for (target_user_id, map) in &body.messages { for (target_user_id, map) in &body.messages {
for (target_device_id_maybe, event) in map { for (target_device_id_maybe, event) in map {
if target_user_id.server_name() != db.globals.server_name() { if target_user_id.server_name() != services().globals.server_name() {
let mut map = BTreeMap::new(); let mut map = BTreeMap::new();
map.insert(target_device_id_maybe.clone(), event.clone()); map.insert(target_device_id_maybe.clone(), event.clone());
let mut messages = BTreeMap::new(); let mut messages = BTreeMap::new();
messages.insert(target_user_id.clone(), map); messages.insert(target_user_id.clone(), map);
db.sending.send_reliable_edu( services().sending.send_reliable_edu(
target_user_id.server_name(), target_user_id.server_name(),
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
DirectDeviceContent { DirectDeviceContent {
@ -48,14 +47,14 @@ pub async fn send_event_to_device_route(
}, },
)) ))
.expect("DirectToDevice EDU can be serialized"), .expect("DirectToDevice EDU can be serialized"),
db.globals.next_count()?, services().globals.next_count()?,
)?; )?;
continue; continue;
} }
match target_device_id_maybe { match target_device_id_maybe {
DeviceIdOrAllDevices::DeviceId(target_device_id) => db.users.add_to_device_event( DeviceIdOrAllDevices::DeviceId(target_device_id) => services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id, &target_device_id,
@ -63,12 +62,11 @@ pub async fn send_event_to_device_route(
event.deserialize_as().map_err(|_| { event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
})?, })?,
&db.globals,
)?, )?,
DeviceIdOrAllDevices::AllDevices => { DeviceIdOrAllDevices::AllDevices => {
for target_device_id in db.users.all_device_ids(target_user_id) { for target_device_id in services().users.all_device_ids(target_user_id) {
db.users.add_to_device_event( services().users.add_to_device_event(
sender_user, sender_user,
target_user_id, target_user_id,
&target_device_id?, &target_device_id?,
@ -76,7 +74,6 @@ pub async fn send_event_to_device_route(
event.deserialize_as().map_err(|_| { event.deserialize_as().map_err(|_| {
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
})?, })?,
&db.globals,
)?; )?;
} }
} }
@ -85,10 +82,8 @@ pub async fn send_event_to_device_route(
} }
// Save transaction id with empty data // Save transaction id with empty data
db.transaction_ids services().transaction_ids
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; .add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
db.flush()?;
Ok(send_event_to_device::v3::Response {}) Ok(send_event_to_device::v3::Response {})
} }

View file

@ -1,18 +1,17 @@
use crate::{database::DatabaseGuard, utils, Error, Result, Ruma}; use crate::{utils, Error, Result, Ruma, services};
use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
/// ///
/// Sets the typing state of the sender user. /// Sets the typing state of the sender user.
pub async fn create_typing_event_route( pub async fn create_typing_event_route(
db: DatabaseGuard,
body: Ruma<create_typing_event::v3::IncomingRequest>, body: Ruma<create_typing_event::v3::IncomingRequest>,
) -> Result<create_typing_event::v3::Response> { ) -> Result<create_typing_event::v3::Response> {
use create_typing_event::v3::Typing; use create_typing_event::v3::Typing;
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
if !db.rooms.is_joined(sender_user, &body.room_id)? { if !services().rooms.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You are not in this room.", "You are not in this room.",
@ -20,16 +19,15 @@ pub async fn create_typing_event_route(
} }
if let Typing::Yes(duration) = body.state { if let Typing::Yes(duration) = body.state {
db.rooms.edus.typing_add( services().rooms.edus.typing_add(
sender_user, sender_user,
&body.room_id, &body.room_id,
duration.as_millis() as u64 + utils::millis_since_unix_epoch(), duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
&db.globals,
)?; )?;
} else { } else {
db.rooms services().rooms
.edus .edus
.typing_remove(sender_user, &body.room_id, &db.globals)?; .typing_remove(sender_user, &body.room_id)?;
} }
Ok(create_typing_event::v3::Response {}) Ok(create_typing_event::v3::Response {})

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Result, Ruma}; use crate::{Result, Ruma, services};
use ruma::{ use ruma::{
api::client::user_directory::search_users, api::client::user_directory::search_users,
events::{ events::{
@ -14,20 +14,19 @@ use ruma::{
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) /// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
/// and don't share a room with the sender /// and don't share a room with the sender
pub async fn search_users_route( pub async fn search_users_route(
db: DatabaseGuard,
body: Ruma<search_users::v3::IncomingRequest>, body: Ruma<search_users::v3::IncomingRequest>,
) -> Result<search_users::v3::Response> { ) -> Result<search_users::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let limit = u64::from(body.limit) as usize; let limit = u64::from(body.limit) as usize;
let mut users = db.users.iter().filter_map(|user_id| { let mut users = services().users.iter().filter_map(|user_id| {
// Filter out buggy users (they should not exist, but you never know...) // Filter out buggy users (they should not exist, but you never know...)
let user_id = user_id.ok()?; let user_id = user_id.ok()?;
let user = search_users::v3::User { let user = search_users::v3::User {
user_id: user_id.clone(), user_id: user_id.clone(),
display_name: db.users.displayname(&user_id).ok()?, display_name: services().users.displayname(&user_id).ok()?,
avatar_url: db.users.avatar_url(&user_id).ok()?, avatar_url: services().users.avatar_url(&user_id).ok()?,
}; };
let user_id_matches = user let user_id_matches = user
@ -50,11 +49,11 @@ pub async fn search_users_route(
} }
let user_is_in_public_rooms = let user_is_in_public_rooms =
db.rooms services().rooms
.rooms_joined(&user_id) .rooms_joined(&user_id)
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.any(|room| { .any(|room| {
db.rooms services().rooms
.room_state_get(&room, &StateEventType::RoomJoinRules, "") .room_state_get(&room, &StateEventType::RoomJoinRules, "")
.map_or(false, |event| { .map_or(false, |event| {
event.map_or(false, |event| { event.map_or(false, |event| {
@ -70,7 +69,7 @@ pub async fn search_users_route(
return Some(user); return Some(user);
} }
let user_is_in_shared_rooms = db let user_is_in_shared_rooms = services()
.rooms .rooms
.get_shared_rooms(vec![sender_user.clone(), user_id.clone()]) .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])
.ok()? .ok()?

View file

@ -1,4 +1,4 @@
use crate::{database::DatabaseGuard, Result, Ruma}; use crate::{Result, Ruma, services};
use hmac::{Hmac, Mac, NewMac}; use hmac::{Hmac, Mac, NewMac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
use sha1::Sha1; use sha1::Sha1;
@ -10,16 +10,15 @@ type HmacSha1 = Hmac<Sha1>;
/// ///
/// TODO: Returns information about the recommended turn server. /// TODO: Returns information about the recommended turn server.
pub async fn turn_server_route( pub async fn turn_server_route(
db: DatabaseGuard,
body: Ruma<get_turn_server_info::v3::IncomingRequest>, body: Ruma<get_turn_server_info::v3::IncomingRequest>,
) -> Result<get_turn_server_info::v3::Response> { ) -> Result<get_turn_server_info::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let turn_secret = db.globals.turn_secret(); let turn_secret = services().globals.turn_secret();
let (username, password) = if !turn_secret.is_empty() { let (username, password) = if !turn_secret.is_empty() {
let expiry = SecondsSinceUnixEpoch::from_system_time( let expiry = SecondsSinceUnixEpoch::from_system_time(
SystemTime::now() + Duration::from_secs(db.globals.turn_ttl()), SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()),
) )
.expect("time is valid"); .expect("time is valid");
@ -34,15 +33,15 @@ pub async fn turn_server_route(
(username, password) (username, password)
} else { } else {
( (
db.globals.turn_username().clone(), services().globals.turn_username().clone(),
db.globals.turn_password().clone(), services().globals.turn_password().clone(),
) )
}; };
Ok(get_turn_server_info::v3::Response { Ok(get_turn_server_info::v3::Response {
username, username,
password, password,
uris: db.globals.turn_uris().to_vec(), uris: services().globals.turn_uris().to_vec(),
ttl: Duration::from_secs(db.globals.turn_ttl()), ttl: Duration::from_secs(services().globals.turn_ttl()),
}) })
} }

4
src/api/mod.rs Normal file
View file

@ -0,0 +1,4 @@
pub mod client_server;
pub mod server_server;
pub mod appservice_server;
pub mod ruma_wrapper;

View file

@ -24,7 +24,7 @@ use serde::Deserialize;
use tracing::{debug, error, warn}; use tracing::{debug, error, warn};
use super::{Ruma, RumaResponse}; use super::{Ruma, RumaResponse};
use crate::{database::DatabaseGuard, server_server, Error, Result}; use crate::{Error, Result, api::server_server, services};
#[async_trait] #[async_trait]
impl<T, B> FromRequest<B> for Ruma<T> impl<T, B> FromRequest<B> for Ruma<T>
@ -44,7 +44,6 @@ where
} }
let metadata = T::METADATA; let metadata = T::METADATA;
let db = DatabaseGuard::from_request(req).await?;
let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?; let auth_header = Option::<TypedHeader<Authorization<Bearer>>>::from_request(req).await?;
let path_params = Path::<Vec<String>>::from_request(req).await?; let path_params = Path::<Vec<String>>::from_request(req).await?;
@ -71,7 +70,7 @@ where
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok(); let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
let appservices = db.appservice.all().unwrap(); let appservices = services().appservice.all().unwrap();
let appservice_registration = appservices.iter().find(|(_id, registration)| { let appservice_registration = appservices.iter().find(|(_id, registration)| {
registration registration
.get("as_token") .get("as_token")
@ -91,14 +90,14 @@ where
.unwrap() .unwrap()
.as_str() .as_str()
.unwrap(), .unwrap(),
db.globals.server_name(), services().globals.server_name(),
) )
.unwrap() .unwrap()
}, },
|s| UserId::parse(s).unwrap(), |s| UserId::parse(s).unwrap(),
); );
if !db.users.exists(&user_id).unwrap() { if !services().users.exists(&user_id).unwrap() {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"User does not exist.", "User does not exist.",
@ -124,7 +123,7 @@ where
} }
}; };
match db.users.find_from_token(token).unwrap() { match services().users.find_from_token(token).unwrap() {
None => { None => {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::UnknownToken { soft_logout: false }, ErrorKind::UnknownToken { soft_logout: false },
@ -185,7 +184,7 @@ where
( (
"destination".to_owned(), "destination".to_owned(),
CanonicalJsonValue::String( CanonicalJsonValue::String(
db.globals.server_name().as_str().to_owned(), services().globals.server_name().as_str().to_owned(),
), ),
), ),
( (
@ -199,7 +198,6 @@ where
}; };
let keys_result = server_server::fetch_signing_keys( let keys_result = server_server::fetch_signing_keys(
&db,
&x_matrix.origin, &x_matrix.origin,
vec![x_matrix.key.to_owned()], vec![x_matrix.key.to_owned()],
) )
@ -251,7 +249,7 @@ where
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
let user_id = sender_user.clone().unwrap_or_else(|| { let user_id = sender_user.clone().unwrap_or_else(|| {
UserId::parse_with_server_name("", db.globals.server_name()) UserId::parse_with_server_name("", services().globals.server_name())
.expect("we know this is valid") .expect("we know this is valid")
}); });
@ -261,7 +259,7 @@ where
.and_then(|auth| auth.get("session")) .and_then(|auth| auth.get("session"))
.and_then(|session| session.as_str()) .and_then(|session| session.as_str())
.and_then(|session| { .and_then(|session| {
db.uiaa.get_uiaa_request( services().uiaa.get_uiaa_request(
&user_id, &user_id,
&sender_device.clone().unwrap_or_else(|| "".into()), &sender_device.clone().unwrap_or_else(|| "".into()),
session, session,

File diff suppressed because it is too large Load diff

View file

@ -30,7 +30,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync {
fn open(config: &Config) -> Result<Self> fn open(config: &Config) -> Result<Self>
where where
Self: Sized; Self: Sized;
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>>; fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
fn flush(&self) -> Result<()>; fn flush(&self) -> Result<()>;
fn cleanup(&self) -> Result<()> { fn cleanup(&self) -> Result<()> {
Ok(()) Ok(())
@ -40,7 +40,7 @@ pub trait KeyValueDatabaseEngine: Send + Sync {
} }
} }
pub trait KeyValueTree: Send + Sync { pub trait KvTree: Send + Sync {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>; fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>>;
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>;

View file

@ -1,4 +1,4 @@
use super::{super::Config, watchers::Watchers, DatabaseEngine, Tree}; use super::{super::Config, watchers::Watchers, KvTree, KeyValueDatabaseEngine};
use crate::{utils, Result}; use crate::{utils, Result};
use std::{ use std::{
future::Future, future::Future,
@ -51,7 +51,7 @@ fn db_options(max_open_files: i32, rocksdb_cache: &rocksdb::Cache) -> rocksdb::O
db_opts db_opts
} }
impl DatabaseEngine for Arc<Engine> { impl KeyValueDatabaseEngine for Arc<Engine> {
fn open(config: &Config) -> Result<Self> { fn open(config: &Config) -> Result<Self> {
let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize;
let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes).unwrap(); let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes).unwrap();
@ -83,7 +83,7 @@ impl DatabaseEngine for Arc<Engine> {
})) }))
} }
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn Tree>> { fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>> {
if !self.old_cfs.contains(&name.to_owned()) { if !self.old_cfs.contains(&name.to_owned()) {
// Create if it didn't exist // Create if it didn't exist
let _ = self let _ = self
@ -129,7 +129,7 @@ impl RocksDbEngineTree<'_> {
} }
} }
impl Tree for RocksDbEngineTree<'_> { impl KvTree for RocksDbEngineTree<'_> {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
Ok(self.db.rocks.get_cf(&self.cf(), key)?) Ok(self.db.rocks.get_cf(&self.cf(), key)?)
} }

View file

@ -1,4 +1,4 @@
use super::{watchers::Watchers, DatabaseEngine, Tree}; use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
use crate::{database::Config, Result}; use crate::{database::Config, Result};
use parking_lot::{Mutex, MutexGuard}; use parking_lot::{Mutex, MutexGuard};
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
@ -80,7 +80,7 @@ impl Engine {
} }
} }
impl DatabaseEngine for Arc<Engine> { impl KeyValueDatabaseEngine for Arc<Engine> {
fn open(config: &Config) -> Result<Self> { fn open(config: &Config) -> Result<Self> {
let path = Path::new(&config.database_path).join("conduit.db"); let path = Path::new(&config.database_path).join("conduit.db");
@ -105,7 +105,7 @@ impl DatabaseEngine for Arc<Engine> {
Ok(arc) Ok(arc)
} }
fn open_tree(&self, name: &str) -> Result<Arc<dyn Tree>> { fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?; self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )", name), [])?;
Ok(Arc::new(SqliteTable { Ok(Arc::new(SqliteTable {
@ -189,7 +189,7 @@ impl SqliteTable {
} }
} }
impl Tree for SqliteTable { impl KvTree for SqliteTable {
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
self.get_with_guard(self.engine.read_lock(), key) self.get_with_guard(self.engine.read_lock(), key)
} }

View file

@ -1,6 +1,8 @@
use crate::{database::KeyValueDatabase, service, utils, Error};
impl service::appservice::Data for KeyValueDatabase { impl service::appservice::Data for KeyValueDatabase {
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> { fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String> {
// TODO: Rumaify // TODO: Rumaify
let id = yaml.get("id").unwrap().as_str().unwrap(); let id = yaml.get("id").unwrap().as_str().unwrap();
self.id_appserviceregistrations.insert( self.id_appserviceregistrations.insert(

View file

@ -1,13 +1,13 @@
mod account_data; //mod account_data;
mod admin; //mod admin;
mod appservice; mod appservice;
mod globals; //mod globals;
mod key_backups; //mod key_backups;
mod media; //mod media;
mod pdu; //mod pdu;
mod pusher; mod pusher;
mod rooms; mod rooms;
mod sending; //mod sending;
mod transaction_ids; mod transaction_ids;
mod uiaa; mod uiaa;
mod users; mod users;

View file

@ -1,3 +1,7 @@
use ruma::{UserId, api::client::push::{set_pusher, get_pushers}};
use crate::{service, database::KeyValueDatabase, Error};
impl service::pusher::Data for KeyValueDatabase { impl service::pusher::Data for KeyValueDatabase {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> { fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()> {
let mut key = sender.as_bytes().to_vec(); let mut key = sender.as_bytes().to_vec();

View file

@ -1,4 +1,8 @@
impl service::room::alias::Data for KeyValueDatabase { use ruma::{RoomId, RoomAliasId, api::client::error::ErrorKind};
use crate::{service, database::KeyValueDatabase, utils, Error, services};
impl service::rooms::alias::Data for KeyValueDatabase {
fn set_alias( fn set_alias(
&self, &self,
alias: &RoomAliasId, alias: &RoomAliasId,
@ -8,7 +12,7 @@ impl service::room::alias::Data for KeyValueDatabase {
.insert(alias.alias().as_bytes(), room_id.as_bytes())?; .insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec(); let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xff); aliasid.push(0xff);
aliasid.extend_from_slice(&globals.next_count()?.to_be_bytes()); aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?; self.aliasid_alias.insert(&aliasid, &*alias.as_bytes())?;
Ok(()) Ok(())
} }

View file

@ -1,10 +1,14 @@
impl service::room::directory::Data for KeyValueDatabase { use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils, Error};
impl service::rooms::directory::Data for KeyValueDatabase {
fn set_public(&self, room_id: &RoomId) -> Result<()> { fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])?; self.publicroomids.insert(room_id.as_bytes(), &[])
} }
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.remove(room_id.as_bytes())?; self.publicroomids.remove(room_id.as_bytes())
} }
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {

View file

@ -0,0 +1,3 @@
mod presence;
mod typing;
mod read_receipt;

View file

@ -1,4 +1,10 @@
impl service::room::edus::presence::Data for KeyValueDatabase { use std::collections::HashMap;
use ruma::{UserId, RoomId, events::presence::PresenceEvent, presence::PresenceState, UInt};
use crate::{service, database::KeyValueDatabase, utils, Error, services};
impl service::rooms::edus::presence::Data for KeyValueDatabase {
fn update_presence( fn update_presence(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -7,7 +13,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase {
) -> Result<()> { ) -> Result<()> {
// TODO: Remove old entry? Or maybe just wipe completely from time to time? // TODO: Remove old entry? Or maybe just wipe completely from time to time?
let count = globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
let mut presence_id = room_id.as_bytes().to_vec(); let mut presence_id = room_id.as_bytes().to_vec();
presence_id.push(0xff); presence_id.push(0xff);
@ -101,6 +107,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase {
Ok(hashmap) Ok(hashmap)
} }
/*
fn presence_maintain(&self, db: Arc<TokioRwLock<Database>>) { fn presence_maintain(&self, db: Arc<TokioRwLock<Database>>) {
// TODO @M0dEx: move this to a timed tasks module // TODO @M0dEx: move this to a timed tasks module
tokio::spawn(async move { tokio::spawn(async move {
@ -117,6 +124,7 @@ impl service::room::edus::presence::Data for KeyValueDatabase {
} }
}); });
} }
*/
} }
fn parse_presence_event(bytes: &[u8]) -> Result<PresenceEvent> { fn parse_presence_event(bytes: &[u8]) -> Result<PresenceEvent> {

View file

@ -1,4 +1,10 @@
impl service::room::edus::read_receipt::Data for KeyValueDatabase { use std::mem;
use ruma::{UserId, RoomId, events::receipt::ReceiptEvent, serde::Raw, signatures::CanonicalJsonObject};
use crate::{database::KeyValueDatabase, service, utils, Error, services};
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update( fn readreceipt_update(
&self, &self,
user_id: &UserId, user_id: &UserId,
@ -28,7 +34,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase {
} }
let mut room_latest_id = prefix; let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&globals.next_count()?.to_be_bytes()); room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
room_latest_id.push(0xff); room_latest_id.push(0xff);
room_latest_id.extend_from_slice(user_id.as_bytes()); room_latest_id.extend_from_slice(user_id.as_bytes());
@ -40,7 +46,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn readreceipts_since<'a>( fn readreceipts_since<'a>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
since: u64, since: u64,
@ -102,7 +108,7 @@ impl service::room::edus::read_receipt::Data for KeyValueDatabase {
.insert(&key, &count.to_be_bytes())?; .insert(&key, &count.to_be_bytes())?;
self.roomuserid_lastprivatereadupdate self.roomuserid_lastprivatereadupdate
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())
} }
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> { fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {

View file

@ -1,15 +1,20 @@
impl service::room::edus::typing::Data for KeyValueDatabase { use std::collections::HashSet;
use ruma::{UserId, RoomId};
use crate::{database::KeyValueDatabase, service, utils, Error, services};
impl service::rooms::edus::typing::Data for KeyValueDatabase {
fn typing_add( fn typing_add(
&self, &self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
timeout: u64, timeout: u64,
globals: &super::super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
let count = globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
let mut room_typing_id = prefix; let mut room_typing_id = prefix;
room_typing_id.extend_from_slice(&timeout.to_be_bytes()); room_typing_id.extend_from_slice(&timeout.to_be_bytes());
@ -49,7 +54,7 @@ impl service::room::edus::typing::Data for KeyValueDatabase {
if found_outdated { if found_outdated {
self.roomid_lasttypingupdate self.roomid_lasttypingupdate
.insert(room_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; .insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
} }
Ok(()) Ok(())

View file

@ -1,4 +1,8 @@
impl service::room::lazy_load::Data for KeyValueDatabase { use ruma::{UserId, DeviceId, RoomId};
use crate::{service, database::KeyValueDatabase};
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
fn lazy_load_was_sent_before( fn lazy_load_was_sent_before(
&self, &self,
user_id: &UserId, user_id: &UserId,

View file

@ -1,4 +1,8 @@
impl service::room::metadata::Data for KeyValueDatabase { use ruma::RoomId;
use crate::{service, database::KeyValueDatabase};
impl service::rooms::metadata::Data for KeyValueDatabase {
fn exists(&self, room_id: &RoomId) -> Result<bool> { fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match self.get_shortroomid(room_id)? { let prefix = match self.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(), Some(b) => b.to_be_bytes().to_vec(),

View file

@ -1,14 +1,13 @@
mod state;
mod alias; mod alias;
mod directory; mod directory;
mod edus; mod edus;
mod event_handler; //mod event_handler;
mod lazy_loading; mod lazy_load;
mod metadata; mod metadata;
mod outlier; mod outlier;
mod pdu_metadata; mod pdu_metadata;
mod search; mod search;
mod short; //mod short;
mod state; mod state;
mod state_accessor; mod state_accessor;
mod state_cache; mod state_cache;

View file

@ -1,4 +1,8 @@
impl service::room::outlier::Data for KeyValueDatabase { use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::{service, database::KeyValueDatabase, PduEvent, Error};
impl service::rooms::outlier::Data for KeyValueDatabase {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu self.eventid_outlierpdu
.get(event_id.as_bytes())? .get(event_id.as_bytes())?

View file

@ -1,4 +1,10 @@
impl service::room::pdu_metadata::Data for KeyValueDatabase { use std::sync::Arc;
use ruma::{RoomId, EventId};
use crate::{service, database::KeyValueDatabase};
impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> { fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
for prev in event_ids { for prev in event_ids {
let mut key = room_id.as_bytes().to_vec(); let mut key = room_id.as_bytes().to_vec();

View file

@ -1,7 +1,12 @@
impl service::room::search::Data for KeyValueDatabase { use std::mem::size_of;
use ruma::RoomId;
use crate::{service, database::KeyValueDatabase, utils};
impl service::rooms::search::Data for KeyValueDatabase {
fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()> { fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()> {
let mut batch = body let mut batch = message_body
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty()) .filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50) .filter(|word| word.len() <= 50)
@ -14,7 +19,7 @@ impl service::room::search::Data for KeyValueDatabase {
(key, Vec::new()) (key, Vec::new())
}); });
self.tokenids.insert_batch(&mut batch)?; self.tokenids.insert_batch(&mut batch)
} }
fn search_pdus<'a>( fn search_pdus<'a>(
@ -64,3 +69,4 @@ impl service::room::search::Data for KeyValueDatabase {
) )
})) }))
} }
}

View file

@ -1,4 +1,11 @@
impl service::room::state::Data for KeyValueDatabase { use ruma::{RoomId, EventId};
use std::sync::Arc;
use std::{sync::MutexGuard, collections::HashSet};
use std::fmt::Debug;
use crate::{service, database::KeyValueDatabase, utils, Error};
impl service::rooms::state::Data for KeyValueDatabase {
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> { fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash self.roomid_shortstatehash
.get(room_id.as_bytes())? .get(room_id.as_bytes())?
@ -9,21 +16,21 @@ impl service::room::state::Data for KeyValueDatabase {
}) })
} }
fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64 fn set_room_state(&self, room_id: &RoomId, new_shortstatehash: u64,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
self.roomid_shortstatehash self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
fn set_event_state(&self) -> Result<()> { fn set_event_state(&self, shorteventid: Vec<u8>, shortstatehash: Vec<u8>) -> Result<()> {
db.shorteventid_shortstatehash self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(()) Ok(())
} }
fn get_pdu_leaves(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> { fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
@ -38,11 +45,11 @@ impl service::room::state::Data for KeyValueDatabase {
.collect() .collect()
} }
fn set_forward_extremities( fn set_forward_extremities<'a>(
&self, &self,
room_id: &RoomId, room_id: &RoomId,
event_ids: impl IntoIterator<Item = &'a EventId> + Debug, event_ids: impl IntoIterator<Item = &'a EventId> + Debug,
_mutex_lock: &MutexGuard<'_, StateLock>, // Take mutex guard to make sure users get the room state mutex _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec(); let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);

View file

@ -1,4 +1,11 @@
impl service::room::state_accessor::Data for KeyValueDatabase { use std::{collections::{BTreeMap, HashMap}, sync::Arc};
use crate::{database::KeyValueDatabase, service, PduEvent, Error, utils};
use async_trait::async_trait;
use ruma::{EventId, events::StateEventType, RoomId};
#[async_trait]
impl service::rooms::state_accessor::Data for KeyValueDatabase {
async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> { async fn state_full_ids(&self, shortstatehash: u64) -> Result<BTreeMap<u64, Arc<EventId>>> {
let full_state = self let full_state = self
.load_shortstatehash_info(shortstatehash)? .load_shortstatehash_info(shortstatehash)?
@ -149,3 +156,4 @@ impl service::room::state_accessor::Data for KeyValueDatabase {
Ok(None) Ok(None)
} }
} }
}

View file

@ -1,8 +1,12 @@
impl service::room::state_cache::Data for KeyValueDatabase { use ruma::{UserId, RoomId};
fn mark_as_once_joined(user_id: &UserId, room_id: &RoomId) -> Result<()> {
use crate::{service, database::KeyValueDatabase};
impl service::rooms::state_cache::Data for KeyValueDatabase {
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);
userroom_id.extend_from_slice(room_id.as_bytes()); userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])?; self.roomuseroncejoinedids.insert(&userroom_id, &[])
} }
} }

View file

@ -1,11 +1,20 @@
impl service::room::state_compressor::Data for KeyValueDatabase { use std::{collections::HashSet, mem::size_of};
fn get_statediff(shortstatehash: u64) -> Result<StateDiff> {
use crate::{service::{self, rooms::state_compressor::data::StateDiff}, database::KeyValueDatabase, Error, utils};
impl service::rooms::state_compressor::Data for KeyValueDatabase {
fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self let value = self
.shortstatehash_statediff .shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())? .get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?; .ok_or_else(|| Error::bad_database("State hash does not exist"))?;
let parent = let parent =
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length"); utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
} else {
None
};
let mut add_mode = true; let mut add_mode = true;
let mut added = HashSet::new(); let mut added = HashSet::new();
@ -26,10 +35,10 @@ impl service::room::state_compressor::Data for KeyValueDatabase {
i += 2 * size_of::<u64>(); i += 2 * size_of::<u64>();
} }
StateDiff { parent, added, removed } Ok(StateDiff { parent, added, removed })
} }
fn save_statediff(shortstatehash: u64, diff: StateDiff) -> Result<()> { fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> {
let mut value = diff.parent.to_be_bytes().to_vec(); let mut value = diff.parent.to_be_bytes().to_vec();
for new in &diff.new { for new in &diff.new {
value.extend_from_slice(&new[..]); value.extend_from_slice(&new[..]);
@ -43,6 +52,6 @@ impl service::room::state_compressor::Data for KeyValueDatabase {
} }
self.shortstatehash_statediff self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)?; .insert(&shortstatehash.to_be_bytes(), &value)
} }
} }

View file

@ -1,4 +1,11 @@
impl service::room::timeline::Data for KeyValueDatabase { use std::{collections::hash_map, mem::size_of, sync::Arc};
use ruma::{UserId, RoomId, api::client::error::ErrorKind, EventId, signatures::CanonicalJsonObject};
use tracing::error;
use crate::{service, database::KeyValueDatabase, utils, Error, PduEvent};
impl service::rooms::timeline::Data for KeyValueDatabase {
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> { fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<u64> {
match self match self
.lasttimelinecount_cache .lasttimelinecount_cache
@ -37,7 +44,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
pub fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map_or_else( .map_or_else(
@ -55,7 +62,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
} }
/// Returns the json of a pdu. /// Returns the json of a pdu.
pub fn get_non_outlier_pdu_json( fn get_non_outlier_pdu_json(
&self, &self,
event_id: &EventId, event_id: &EventId,
) -> Result<Option<CanonicalJsonObject>> { ) -> Result<Option<CanonicalJsonObject>> {
@ -74,14 +81,14 @@ impl service::room::timeline::Data for KeyValueDatabase {
} }
/// Returns the pdu's id. /// Returns the pdu's id.
pub fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
self.eventid_pduid.get(event_id.as_bytes()) self.eventid_pduid.get(event_id.as_bytes())
} }
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> { fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid self.eventid_pduid
.get(event_id.as_bytes())? .get(event_id.as_bytes())?
.map(|pduid| { .map(|pduid| {
@ -99,7 +106,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline. /// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> { fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) {
return Ok(Some(Arc::clone(p))); return Ok(Some(Arc::clone(p)));
} }
@ -135,7 +142,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
/// Returns the pdu. /// Returns the pdu.
/// ///
/// This does __NOT__ check the outliers `Tree`. /// This does __NOT__ check the outliers `Tree`.
pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> { fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu)
@ -145,7 +152,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
} }
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> { fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some( Ok(Some(
serde_json::from_slice(&pdu) serde_json::from_slice(&pdu)
@ -155,7 +162,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
} }
/// Returns the `count` of this pdu's id. /// Returns the `count` of this pdu's id.
pub fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> { fn pdu_count(&self, pdu_id: &[u8]) -> Result<u64> {
utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..]) utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes.")) .map_err(|_| Error::bad_database("PDU has invalid count bytes."))
} }
@ -178,7 +185,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
/// Returns an iterator over all events in a room that happened after the event with id `since` /// Returns an iterator over all events in a room that happened after the event with id `since`
/// in chronological order. /// in chronological order.
pub fn pdus_since<'a>( fn pdus_since<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -212,7 +219,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
/// Returns an iterator over all events and their tokens in a room that happened before the /// Returns an iterator over all events and their tokens in a room that happened before the
/// event with id `until` in reverse-chronological order. /// event with id `until` in reverse-chronological order.
pub fn pdus_until<'a>( fn pdus_until<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,
@ -246,7 +253,7 @@ impl service::room::timeline::Data for KeyValueDatabase {
})) }))
} }
pub fn pdus_after<'a>( fn pdus_after<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
room_id: &RoomId, room_id: &RoomId,

View file

@ -1,4 +1,8 @@
impl service::room::user::Data for KeyValueDatabase { use ruma::{UserId, RoomId};
use crate::{service, database::KeyValueDatabase, utils, Error};
impl service::rooms::user::Data for KeyValueDatabase {
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
let mut userroom_id = user_id.as_bytes().to_vec(); let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xff); userroom_id.push(0xff);

View file

@ -1,5 +1,9 @@
impl service::pusher::Data for KeyValueDatabase { use ruma::{UserId, DeviceId, TransactionId};
pub fn add_txnid(
use crate::{service, database::KeyValueDatabase};
impl service::transaction_ids::Data for KeyValueDatabase {
fn add_txnid(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: Option<&DeviceId>, device_id: Option<&DeviceId>,
@ -17,7 +21,7 @@ impl service::pusher::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn existing_txnid( fn existing_txnid(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: Option<&DeviceId>, device_id: Option<&DeviceId>,

View file

@ -1,3 +1,9 @@
use std::io::ErrorKind;
use ruma::{UserId, DeviceId, signatures::CanonicalJsonValue, api::client::uiaa::UiaaInfo};
use crate::{database::KeyValueDatabase, service, Error};
impl service::uiaa::Data for KeyValueDatabase { impl service::uiaa::Data for KeyValueDatabase {
fn set_uiaa_request( fn set_uiaa_request(
&self, &self,

View file

@ -1,11 +1,18 @@
use std::{mem::size_of, collections::BTreeMap};
use ruma::{api::client::{filter::IncomingFilterDefinition, error::ErrorKind, device::Device}, UserId, RoomAliasId, MxcUri, DeviceId, MilliSecondsSinceUnixEpoch, DeviceKeyId, encryption::{OneTimeKey, CrossSigningKey, DeviceKeys}, serde::Raw, events::{AnyToDeviceEvent, StateEventType}, DeviceKeyAlgorithm, UInt};
use tracing::warn;
use crate::{service::{self, users::clean_signatures}, database::KeyValueDatabase, Error, utils, services};
impl service::users::Data for KeyValueDatabase { impl service::users::Data for KeyValueDatabase {
/// Check if a user has an account on this homeserver. /// Check if a user has an account on this homeserver.
pub fn exists(&self, user_id: &UserId) -> Result<bool> { fn exists(&self, user_id: &UserId) -> Result<bool> {
Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
} }
/// Check if account is deactivated /// Check if account is deactivated
pub fn is_deactivated(&self, user_id: &UserId) -> Result<bool> { fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
Ok(self Ok(self
.userid_password .userid_password
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
@ -16,33 +23,13 @@ impl service::users::Data for KeyValueDatabase {
.is_empty()) .is_empty())
} }
/// Check if a user is an admin
pub fn is_admin(
&self,
user_id: &UserId,
rooms: &super::rooms::Rooms,
globals: &super::globals::Globals,
) -> Result<bool> {
let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", globals.server_name()))
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?;
let admin_room_id = rooms.id_from_alias(&admin_room_alias_id)?.unwrap();
rooms.is_joined(user_id, &admin_room_id)
}
/// Create a new user account on this homeserver.
pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
self.set_password(user_id, password)?;
Ok(())
}
/// Returns the number of users registered on this server. /// Returns the number of users registered on this server.
pub fn count(&self) -> Result<usize> { fn count(&self) -> Result<usize> {
Ok(self.userid_password.iter().count()) Ok(self.userid_password.iter().count())
} }
/// Find out which user an access token belongs to. /// Find out which user an access token belongs to.
pub fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> { fn find_from_token(&self, token: &str) -> Result<Option<(Box<UserId>, String)>> {
self.token_userdeviceid self.token_userdeviceid
.get(token.as_bytes())? .get(token.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -69,7 +56,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Returns an iterator over all users on this homeserver. /// Returns an iterator over all users on this homeserver.
pub fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ { fn iter(&self) -> impl Iterator<Item = Result<Box<UserId>>> + '_ {
self.userid_password.iter().map(|(bytes, _)| { self.userid_password.iter().map(|(bytes, _)| {
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("User ID in userid_password is invalid unicode.") Error::bad_database("User ID in userid_password is invalid unicode.")
@ -81,7 +68,7 @@ impl service::users::Data for KeyValueDatabase {
/// Returns a list of local users as list of usernames. /// Returns a list of local users as list of usernames.
/// ///
/// A user account is considered `local` if the length of it's password is greater then zero. /// A user account is considered `local` if the length of it's password is greater then zero.
pub fn list_local_users(&self) -> Result<Vec<String>> { fn list_local_users(&self) -> Result<Vec<String>> {
let users: Vec<String> = self let users: Vec<String> = self
.userid_password .userid_password
.iter() .iter()
@ -113,7 +100,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Returns the password hash for the given user. /// Returns the password hash for the given user.
pub fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> { fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_password self.userid_password
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -124,7 +111,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Hash and set the user's password to the Argon2 hash /// Hash and set the user's password to the Argon2 hash
pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
if let Some(password) = password { if let Some(password) = password {
if let Ok(hash) = utils::calculate_hash(password) { if let Ok(hash) = utils::calculate_hash(password) {
self.userid_password self.userid_password
@ -143,7 +130,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Returns the displayname of a user on this homeserver. /// Returns the displayname of a user on this homeserver.
pub fn displayname(&self, user_id: &UserId) -> Result<Option<String>> { fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_displayname self.userid_displayname
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -154,7 +141,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
pub fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> { fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
if let Some(displayname) = displayname { if let Some(displayname) = displayname {
self.userid_displayname self.userid_displayname
.insert(user_id.as_bytes(), displayname.as_bytes())?; .insert(user_id.as_bytes(), displayname.as_bytes())?;
@ -166,7 +153,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Get the avatar_url of a user. /// Get the avatar_url of a user.
pub fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> { fn avatar_url(&self, user_id: &UserId) -> Result<Option<Box<MxcUri>>> {
self.userid_avatarurl self.userid_avatarurl
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -179,7 +166,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> { fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<Box<MxcUri>>) -> Result<()> {
if let Some(avatar_url) = avatar_url { if let Some(avatar_url) = avatar_url {
self.userid_avatarurl self.userid_avatarurl
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
@ -191,7 +178,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Get the blurhash of a user. /// Get the blurhash of a user.
pub fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> { fn blurhash(&self, user_id: &UserId) -> Result<Option<String>> {
self.userid_blurhash self.userid_blurhash
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -204,7 +191,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Sets a new avatar_url or removes it if avatar_url is None. /// Sets a new avatar_url or removes it if avatar_url is None.
pub fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> { fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
if let Some(blurhash) = blurhash { if let Some(blurhash) = blurhash {
self.userid_blurhash self.userid_blurhash
.insert(user_id.as_bytes(), blurhash.as_bytes())?; .insert(user_id.as_bytes(), blurhash.as_bytes())?;
@ -216,7 +203,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Adds a new device to a user. /// Adds a new device to a user.
pub fn create_device( fn create_device(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -250,7 +237,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Removes a device from a user. /// Removes a device from a user.
pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -280,7 +267,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Returns an iterator over all device ids of this user. /// Returns an iterator over all device ids of this user.
pub fn all_device_ids<'a>( fn all_device_ids<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a { ) -> impl Iterator<Item = Result<Box<DeviceId>>> + 'a {
@ -302,7 +289,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Replaces the access token of one device. /// Replaces the access token of one device.
pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
userdeviceid.extend_from_slice(device_id.as_bytes()); userdeviceid.extend_from_slice(device_id.as_bytes());
@ -325,13 +312,12 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn add_one_time_key( fn add_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
one_time_key_key: &DeviceKeyId, one_time_key_key: &DeviceKeyId,
one_time_key_value: &Raw<OneTimeKey>, one_time_key_value: &Raw<OneTimeKey>,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -356,12 +342,12 @@ impl service::users::Data for KeyValueDatabase {
)?; )?;
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
Ok(()) Ok(())
} }
pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> { fn last_one_time_keys_update(&self, user_id: &UserId) -> Result<u64> {
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map(|bytes| { .map(|bytes| {
@ -372,12 +358,11 @@ impl service::users::Data for KeyValueDatabase {
.unwrap_or(Ok(0)) .unwrap_or(Ok(0))
} }
pub fn take_one_time_key( fn take_one_time_key(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
key_algorithm: &DeviceKeyAlgorithm, key_algorithm: &DeviceKeyAlgorithm,
globals: &super::globals::Globals,
) -> Result<Option<(Box<DeviceKeyId>, Raw<OneTimeKey>)>> { ) -> Result<Option<(Box<DeviceKeyId>, Raw<OneTimeKey>)>> {
let mut prefix = user_id.as_bytes().to_vec(); let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xff); prefix.push(0xff);
@ -388,7 +373,7 @@ impl service::users::Data for KeyValueDatabase {
prefix.push(b':'); prefix.push(b':');
self.userid_lastonetimekeyupdate self.userid_lastonetimekeyupdate
.insert(user_id.as_bytes(), &globals.next_count()?.to_be_bytes())?; .insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
self.onetimekeyid_onetimekeys self.onetimekeyid_onetimekeys
.scan_prefix(prefix) .scan_prefix(prefix)
@ -411,7 +396,7 @@ impl service::users::Data for KeyValueDatabase {
.transpose() .transpose()
} }
pub fn count_one_time_keys( fn count_one_time_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -443,13 +428,11 @@ impl service::users::Data for KeyValueDatabase {
Ok(counts) Ok(counts)
} }
pub fn add_device_keys( fn add_device_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
device_keys: &Raw<DeviceKeys>, device_keys: &Raw<DeviceKeys>,
rooms: &super::rooms::Rooms,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut userdeviceid = user_id.as_bytes().to_vec(); let mut userdeviceid = user_id.as_bytes().to_vec();
userdeviceid.push(0xff); userdeviceid.push(0xff);
@ -460,19 +443,17 @@ impl service::users::Data for KeyValueDatabase {
&serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"),
)?; )?;
self.mark_device_key_update(user_id, rooms, globals)?; self.mark_device_key_update(user_id)?;
Ok(()) Ok(())
} }
pub fn add_cross_signing_keys( fn add_cross_signing_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
master_key: &Raw<CrossSigningKey>, master_key: &Raw<CrossSigningKey>,
self_signing_key: &Option<Raw<CrossSigningKey>>, self_signing_key: &Option<Raw<CrossSigningKey>>,
user_signing_key: &Option<Raw<CrossSigningKey>>, user_signing_key: &Option<Raw<CrossSigningKey>>,
rooms: &super::rooms::Rooms,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
// TODO: Check signatures // TODO: Check signatures
@ -575,19 +556,17 @@ impl service::users::Data for KeyValueDatabase {
.insert(user_id.as_bytes(), &user_signing_key_key)?; .insert(user_id.as_bytes(), &user_signing_key_key)?;
} }
self.mark_device_key_update(user_id, rooms, globals)?; self.mark_device_key_update(user_id)?;
Ok(()) Ok(())
} }
pub fn sign_key( fn sign_key(
&self, &self,
target_id: &UserId, target_id: &UserId,
key_id: &str, key_id: &str,
signature: (String, String), signature: (String, String),
sender_id: &UserId, sender_id: &UserId,
rooms: &super::rooms::Rooms,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut key = target_id.as_bytes().to_vec(); let mut key = target_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -619,12 +598,12 @@ impl service::users::Data for KeyValueDatabase {
)?; )?;
// TODO: Should we notify about this change? // TODO: Should we notify about this change?
self.mark_device_key_update(target_id, rooms, globals)?; self.mark_device_key_update(target_id)?;
Ok(()) Ok(())
} }
pub fn keys_changed<'a>( fn keys_changed<'a>(
&'a self, &'a self,
user_or_room_id: &str, user_or_room_id: &str,
from: u64, from: u64,
@ -662,16 +641,14 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn mark_device_key_update( fn mark_device_key_update(
&self, &self,
user_id: &UserId, user_id: &UserId,
rooms: &super::rooms::Rooms,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let count = globals.next_count()?.to_be_bytes(); let count = services().globals.next_count()?.to_be_bytes();
for room_id in rooms.rooms_joined(user_id).filter_map(|r| r.ok()) { for room_id in services().rooms.rooms_joined(user_id).filter_map(|r| r.ok()) {
// Don't send key updates to unencrypted rooms // Don't send key updates to unencrypted rooms
if rooms if services().rooms
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")? .room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
.is_none() .is_none()
{ {
@ -693,7 +670,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn get_device_keys( fn get_device_keys(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -709,7 +686,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn get_master_key<F: Fn(&UserId) -> bool>( fn get_master_key<F: Fn(&UserId) -> bool>(
&self, &self,
user_id: &UserId, user_id: &UserId,
allowed_signatures: F, allowed_signatures: F,
@ -730,7 +707,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn get_self_signing_key<F: Fn(&UserId) -> bool>( fn get_self_signing_key<F: Fn(&UserId) -> bool>(
&self, &self,
user_id: &UserId, user_id: &UserId,
allowed_signatures: F, allowed_signatures: F,
@ -751,7 +728,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> { fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
self.userid_usersigningkeyid self.userid_usersigningkeyid
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |key| { .map_or(Ok(None), |key| {
@ -763,20 +740,19 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn add_to_device_event( fn add_to_device_event(
&self, &self,
sender: &UserId, sender: &UserId,
target_user_id: &UserId, target_user_id: &UserId,
target_device_id: &DeviceId, target_device_id: &DeviceId,
event_type: &str, event_type: &str,
content: serde_json::Value, content: serde_json::Value,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut key = target_user_id.as_bytes().to_vec(); let mut key = target_user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
key.extend_from_slice(target_device_id.as_bytes()); key.extend_from_slice(target_device_id.as_bytes());
key.push(0xff); key.push(0xff);
key.extend_from_slice(&globals.next_count()?.to_be_bytes()); key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
let mut json = serde_json::Map::new(); let mut json = serde_json::Map::new();
json.insert("type".to_owned(), event_type.to_owned().into()); json.insert("type".to_owned(), event_type.to_owned().into());
@ -790,7 +766,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn get_to_device_events( fn get_to_device_events(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -812,7 +788,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(events) Ok(events)
} }
pub fn remove_to_device_events( fn remove_to_device_events(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -833,7 +809,7 @@ impl service::users::Data for KeyValueDatabase {
.map(|(key, _)| { .map(|(key, _)| {
Ok::<_, Error>(( Ok::<_, Error>((
key.clone(), key.clone(),
utils::u64_from_bytes(&key[key.len() - mem::size_of::<u64>()..key.len()]) utils::u64_from_bytes(&key[key.len() - size_of::<u64>()..key.len()])
.map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?,
)) ))
}) })
@ -846,7 +822,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(()) Ok(())
} }
pub fn update_device_metadata( fn update_device_metadata(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -871,7 +847,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Get device metadata. /// Get device metadata.
pub fn get_device_metadata( fn get_device_metadata(
&self, &self,
user_id: &UserId, user_id: &UserId,
device_id: &DeviceId, device_id: &DeviceId,
@ -889,7 +865,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> { fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
self.userid_devicelistversion self.userid_devicelistversion
.get(user_id.as_bytes())? .get(user_id.as_bytes())?
.map_or(Ok(None), |bytes| { .map_or(Ok(None), |bytes| {
@ -899,7 +875,7 @@ impl service::users::Data for KeyValueDatabase {
}) })
} }
pub fn all_devices_metadata<'a>( fn all_devices_metadata<'a>(
&'a self, &'a self,
user_id: &UserId, user_id: &UserId,
) -> impl Iterator<Item = Result<Device>> + 'a { ) -> impl Iterator<Item = Result<Device>> + 'a {
@ -915,7 +891,7 @@ impl service::users::Data for KeyValueDatabase {
} }
/// Creates a new sync filter. Returns the filter id. /// Creates a new sync filter. Returns the filter id.
pub fn create_filter( fn create_filter(
&self, &self,
user_id: &UserId, user_id: &UserId,
filter: &IncomingFilterDefinition, filter: &IncomingFilterDefinition,
@ -934,7 +910,7 @@ impl service::users::Data for KeyValueDatabase {
Ok(filter_id) Ok(filter_id)
} }
pub fn get_filter( fn get_filter(
&self, &self,
user_id: &UserId, user_id: &UserId,
filter_id: &str, filter_id: &str,

View file

@ -1,20 +1,7 @@
pub mod abstraction; pub mod abstraction;
pub mod key_value;
pub mod account_data; use crate::{utils, Config, Error, Result, service::{users, globals, uiaa, rooms, account_data, media, key_backups, transaction_ids, sending, admin::{self, create_admin_room}, appservice, pusher}};
pub mod admin;
pub mod appservice;
pub mod globals;
pub mod key_backups;
pub mod media;
pub mod pusher;
pub mod rooms;
pub mod sending;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;
use self::admin::create_admin_room;
use crate::{utils, Config, Error, Result};
use abstraction::KeyValueDatabaseEngine; use abstraction::KeyValueDatabaseEngine;
use directories::ProjectDirs; use directories::ProjectDirs;
use futures_util::{stream::FuturesUnordered, StreamExt}; use futures_util::{stream::FuturesUnordered, StreamExt};
@ -25,7 +12,7 @@ use ruma::{
GlobalAccountDataEvent, GlobalAccountDataEventType, GlobalAccountDataEvent, GlobalAccountDataEventType,
}, },
push::Ruleset, push::Ruleset,
DeviceId, EventId, RoomId, UserId, DeviceId, EventId, RoomId, UserId, signatures::CanonicalJsonValue,
}; };
use std::{ use std::{
collections::{BTreeMap, HashMap, HashSet}, collections::{BTreeMap, HashMap, HashSet},
@ -38,21 +25,132 @@ use std::{
}; };
use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore}; use tokio::sync::{mpsc, OwnedRwLockReadGuard, RwLock as TokioRwLock, Semaphore};
use tracing::{debug, error, info, warn}; use tracing::{debug, error, info, warn};
use abstraction::KvTree;
pub struct KeyValueDatabase { pub struct KeyValueDatabase {
_db: Arc<dyn KeyValueDatabaseEngine>, _db: Arc<dyn KeyValueDatabaseEngine>,
pub globals: globals::Globals,
pub users: users::Users, //pub globals: globals::Globals,
pub uiaa: uiaa::Uiaa, pub(super) global: Arc<dyn KvTree>,
pub rooms: rooms::Rooms, pub(super) server_signingkeys: Arc<dyn KvTree>,
pub account_data: account_data::AccountData,
pub media: media::Media, //pub users: users::Users,
pub key_backups: key_backups::KeyBackups, pub(super) userid_password: Arc<dyn KvTree>,
pub transaction_ids: transaction_ids::TransactionIds, pub(super) userid_displayname: Arc<dyn KvTree>,
pub sending: sending::Sending, pub(super) userid_avatarurl: Arc<dyn KvTree>,
pub admin: admin::Admin, pub(super) userid_blurhash: Arc<dyn KvTree>,
pub appservice: appservice::Appservice, pub(super) userdeviceid_token: Arc<dyn KvTree>,
pub pusher: pusher::PushData, pub(super) userdeviceid_metadata: Arc<dyn KvTree>, // This is also used to check if a device exists
pub(super) userid_devicelistversion: Arc<dyn KvTree>, // DevicelistVersion = u64
pub(super) token_userdeviceid: Arc<dyn KvTree>,
pub(super) onetimekeyid_onetimekeys: Arc<dyn KvTree>, // OneTimeKeyId = UserId + DeviceKeyId
pub(super) userid_lastonetimekeyupdate: Arc<dyn KvTree>, // LastOneTimeKeyUpdate = Count
pub(super) keychangeid_userid: Arc<dyn KvTree>, // KeyChangeId = UserId/RoomId + Count
pub(super) keyid_key: Arc<dyn KvTree>, // KeyId = UserId + KeyId (depends on key type)
pub(super) userid_masterkeyid: Arc<dyn KvTree>,
pub(super) userid_selfsigningkeyid: Arc<dyn KvTree>,
pub(super) userid_usersigningkeyid: Arc<dyn KvTree>,
pub(super) userfilterid_filter: Arc<dyn KvTree>, // UserFilterId = UserId + FilterId
pub(super) todeviceid_events: Arc<dyn KvTree>, // ToDeviceId = UserId + DeviceId + Count
//pub uiaa: uiaa::Uiaa,
pub(super) userdevicesessionid_uiaainfo: Arc<dyn KvTree>, // User-interactive authentication
pub(super) userdevicesessionid_uiaarequest:
RwLock<BTreeMap<(Box<UserId>, Box<DeviceId>, String), CanonicalJsonValue>>,
//pub edus: RoomEdus,
pub(super) readreceiptid_readreceipt: Arc<dyn KvTree>, // ReadReceiptId = RoomId + Count + UserId
pub(super) roomuserid_privateread: Arc<dyn KvTree>, // RoomUserId = Room + User, PrivateRead = Count
pub(super) roomuserid_lastprivatereadupdate: Arc<dyn KvTree>, // LastPrivateReadUpdate = Count
pub(super) typingid_userid: Arc<dyn KvTree>, // TypingId = RoomId + TimeoutTime + Count
pub(super) roomid_lasttypingupdate: Arc<dyn KvTree>, // LastRoomTypingUpdate = Count
pub(super) presenceid_presence: Arc<dyn KvTree>, // PresenceId = RoomId + Count + UserId
pub(super) userid_lastpresenceupdate: Arc<dyn KvTree>, // LastPresenceUpdate = Count
//pub rooms: rooms::Rooms,
pub(super) pduid_pdu: Arc<dyn KvTree>, // PduId = ShortRoomId + Count
pub(super) eventid_pduid: Arc<dyn KvTree>,
pub(super) roomid_pduleaves: Arc<dyn KvTree>,
pub(super) alias_roomid: Arc<dyn KvTree>,
pub(super) aliasid_alias: Arc<dyn KvTree>, // AliasId = RoomId + Count
pub(super) publicroomids: Arc<dyn KvTree>,
pub(super) tokenids: Arc<dyn KvTree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub(super) roomserverids: Arc<dyn KvTree>, // RoomServerId = RoomId + ServerName
pub(super) serverroomids: Arc<dyn KvTree>, // ServerRoomId = ServerName + RoomId
pub(super) userroomid_joined: Arc<dyn KvTree>,
pub(super) roomuserid_joined: Arc<dyn KvTree>,
pub(super) roomid_joinedcount: Arc<dyn KvTree>,
pub(super) roomid_invitedcount: Arc<dyn KvTree>,
pub(super) roomuseroncejoinedids: Arc<dyn KvTree>,
pub(super) userroomid_invitestate: Arc<dyn KvTree>, // InviteState = Vec<Raw<Pdu>>
pub(super) roomuserid_invitecount: Arc<dyn KvTree>, // InviteCount = Count
pub(super) userroomid_leftstate: Arc<dyn KvTree>,
pub(super) roomuserid_leftcount: Arc<dyn KvTree>,
pub(super) disabledroomids: Arc<dyn KvTree>, // Rooms where incoming federation handling is disabled
pub(super) lazyloadedids: Arc<dyn KvTree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub(super) userroomid_notificationcount: Arc<dyn KvTree>, // NotifyCount = u64
pub(super) userroomid_highlightcount: Arc<dyn KvTree>, // HightlightCount = u64
/// Remember the current state hash of a room.
pub(super) roomid_shortstatehash: Arc<dyn KvTree>,
pub(super) roomsynctoken_shortstatehash: Arc<dyn KvTree>,
/// Remember the state hash at events in the past.
pub(super) shorteventid_shortstatehash: Arc<dyn KvTree>,
/// StateKey = EventType + StateKey, ShortStateKey = Count
pub(super) statekey_shortstatekey: Arc<dyn KvTree>,
pub(super) shortstatekey_statekey: Arc<dyn KvTree>,
pub(super) roomid_shortroomid: Arc<dyn KvTree>,
pub(super) shorteventid_eventid: Arc<dyn KvTree>,
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
pub(super) softfailedeventids: Arc<dyn KvTree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(super) referencedevents: Arc<dyn KvTree>,
//pub account_data: account_data::AccountData,
pub(super) roomuserdataid_accountdata: Arc<dyn KvTree>, // RoomUserDataId = Room + User + Count + Type
pub(super) roomusertype_roomuserdataid: Arc<dyn KvTree>, // RoomUserType = Room + User + Type
//pub media: media::Media,
pub(super) mediaid_file: Arc<dyn KvTree>, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType
//pub key_backups: key_backups::KeyBackups,
pub(super) backupid_algorithm: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(super) backupid_etag: Arc<dyn KvTree>, // BackupId = UserId + Version(Count)
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
//pub transaction_ids: transaction_ids::TransactionIds,
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, // Response can be empty (/sendToDevice) or the event id (/send)
//pub sending: sending::Sending,
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
pub(super) servernameevent_data: Arc<dyn KvTree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
pub(super) servercurrentevent_data: Arc<dyn KvTree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
//pub appservice: appservice::Appservice,
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
//pub pusher: pusher::PushData,
pub(super) senderkey_pusher: Arc<dyn KvTree>,
} }
impl KeyValueDatabase { impl KeyValueDatabase {
@ -157,7 +255,6 @@ impl KeyValueDatabase {
let db = Arc::new(TokioRwLock::from(Self { let db = Arc::new(TokioRwLock::from(Self {
_db: builder.clone(), _db: builder.clone(),
users: users::Users {
userid_password: builder.open_tree("userid_password")?, userid_password: builder.open_tree("userid_password")?,
userid_displayname: builder.open_tree("userid_displayname")?, userid_displayname: builder.open_tree("userid_displayname")?,
userid_avatarurl: builder.open_tree("userid_avatarurl")?, userid_avatarurl: builder.open_tree("userid_avatarurl")?,
@ -175,13 +272,9 @@ impl KeyValueDatabase {
userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?,
userfilterid_filter: builder.open_tree("userfilterid_filter")?, userfilterid_filter: builder.open_tree("userfilterid_filter")?,
todeviceid_events: builder.open_tree("todeviceid_events")?, todeviceid_events: builder.open_tree("todeviceid_events")?,
},
uiaa: uiaa::Uiaa {
userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?,
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
},
rooms: rooms::Rooms {
edus: rooms::RoomEdus {
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
roomuserid_lastprivatereadupdate: builder roomuserid_lastprivatereadupdate: builder
@ -190,7 +283,6 @@ impl KeyValueDatabase {
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
presenceid_presence: builder.open_tree("presenceid_presence")?, presenceid_presence: builder.open_tree("presenceid_presence")?,
userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?, userid_lastpresenceupdate: builder.open_tree("userid_lastpresenceupdate")?,
},
pduid_pdu: builder.open_tree("pduid_pdu")?, pduid_pdu: builder.open_tree("pduid_pdu")?,
eventid_pduid: builder.open_tree("eventid_pduid")?, eventid_pduid: builder.open_tree("eventid_pduid")?,
roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, roomid_pduleaves: builder.open_tree("roomid_pduleaves")?,
@ -239,74 +331,23 @@ impl KeyValueDatabase {
softfailedeventids: builder.open_tree("softfailedeventids")?, softfailedeventids: builder.open_tree("softfailedeventids")?,
referencedevents: builder.open_tree("referencedevents")?, referencedevents: builder.open_tree("referencedevents")?,
pdu_cache: Mutex::new(LruCache::new(
config
.pdu_cache_capacity
.try_into()
.expect("pdu cache capacity fits into usize"),
)),
auth_chain_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
shorteventid_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
eventidshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
shortstatekey_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
statekeyshort_cache: Mutex::new(LruCache::new(
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
)),
our_real_users_cache: RwLock::new(HashMap::new()),
appservice_in_room_cache: RwLock::new(HashMap::new()),
lazy_load_waiting: Mutex::new(HashMap::new()),
stateinfo_cache: Mutex::new(LruCache::new(
(100.0 * config.conduit_cache_capacity_modifier) as usize,
)),
lasttimelinecount_cache: Mutex::new(HashMap::new()),
},
account_data: account_data::AccountData {
roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?,
roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?,
},
media: media::Media {
mediaid_file: builder.open_tree("mediaid_file")?, mediaid_file: builder.open_tree("mediaid_file")?,
},
key_backups: key_backups::KeyBackups {
backupid_algorithm: builder.open_tree("backupid_algorithm")?, backupid_algorithm: builder.open_tree("backupid_algorithm")?,
backupid_etag: builder.open_tree("backupid_etag")?, backupid_etag: builder.open_tree("backupid_etag")?,
backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, backupkeyid_backup: builder.open_tree("backupkeyid_backup")?,
},
transaction_ids: transaction_ids::TransactionIds {
userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?,
},
sending: sending::Sending {
servername_educount: builder.open_tree("servername_educount")?, servername_educount: builder.open_tree("servername_educount")?,
servernameevent_data: builder.open_tree("servernameevent_data")?, servernameevent_data: builder.open_tree("servernameevent_data")?,
servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, servercurrentevent_data: builder.open_tree("servercurrentevent_data")?,
maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)),
sender: sending_sender,
},
admin: admin::Admin {
sender: admin_sender,
},
appservice: appservice::Appservice {
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?,
},
pusher: pusher::PushData {
senderkey_pusher: builder.open_tree("senderkey_pusher")?, senderkey_pusher: builder.open_tree("senderkey_pusher")?,
}, global: builder.open_tree("global")?,
globals: globals::Globals::load( server_signingkeys: builder.open_tree("server_signingkeys")?,
builder.open_tree("global")?,
builder.open_tree("server_signingkeys")?,
config.clone(),
)?,
})); }));
// TODO: do this after constructing the db
let guard = db.read().await; let guard = db.read().await;
// Matrix resource ownership is based on the server name; changing it // Matrix resource ownership is based on the server name; changing it
@ -744,7 +785,7 @@ impl KeyValueDatabase {
.bump_database_version(latest_database_version)?; .bump_database_version(latest_database_version)?;
// Create the admin room and server user on first run // Create the admin room and server user on first run
create_admin_room(&guard).await?; create_admin_room().await?;
warn!( warn!(
"Created new {} database with version {}", "Created new {} database with version {}",

View file

@ -9,17 +9,26 @@
mod config; mod config;
mod database; mod database;
mod error; mod service;
mod pdu; pub mod api;
mod ruma_wrapper;
mod utils; mod utils;
pub mod appservice_server; use std::cell::Cell;
pub mod client_server;
pub mod server_server;
pub use config::Config; pub use config::Config;
pub use database::Database; pub use utils::error::{Error, Result};
pub use error::{Error, Result}; pub use service::{Services, pdu::PduEvent};
pub use pdu::PduEvent; pub use api::ruma_wrapper::{Ruma, RumaResponse};
pub use ruma_wrapper::{Ruma, RumaResponse};
use crate::database::KeyValueDatabase;
pub static SERVICES: Cell<Option<ServicesEnum>> = Cell::new(None);
enum ServicesEnum {
Rocksdb(Services<KeyValueDatabase>)
}
pub fn services() -> Services {
SERVICES.get().unwrap()
}

View file

@ -46,47 +46,44 @@ use tikv_jemallocator::Jemalloc;
#[global_allocator] #[global_allocator]
static GLOBAL: Jemalloc = Jemalloc; static GLOBAL: Jemalloc = Jemalloc;
lazy_static! {
static ref DB: Database = {
let raw_config =
Figment::new()
.merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
))
.nested(),
)
.merge(Env::prefixed("CONDUIT_").global());
let config = match raw_config.extract::<Config>() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e);
std::process::exit(1);
}
};
config.warn_deprecated();
let db = match Database::load_or_create(&config).await {
Ok(db) => db,
Err(e) => {
eprintln!(
"The database couldn't be loaded or created. The following error occured: {}",
e
);
std::process::exit(1);
}
};
};
}
#[tokio::main] #[tokio::main]
async fn main() { async fn main() {
lazy_static::initialize(&DB); // Initialize DB
let raw_config =
Figment::new()
.merge(
Toml::file(Env::var("CONDUIT_CONFIG").expect(
"The CONDUIT_CONFIG env var needs to be set. Example: /etc/conduit.toml",
))
.nested(),
)
.merge(Env::prefixed("CONDUIT_").global());
let config = match raw_config.extract::<Config>() {
Ok(s) => s,
Err(e) => {
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {}", e);
std::process::exit(1);
}
};
config.warn_deprecated();
let db = match KeyValueDatabase::load_or_create(&config).await {
Ok(db) => db,
Err(e) => {
eprintln!(
"The database couldn't be loaded or created. The following error occured: {}",
e
);
std::process::exit(1);
}
};
SERVICES.set(db).expect("this is the first and only time we initialize the SERVICE static");
let start = async { let start = async {
run_server(&config).await.unwrap(); run_server().await.unwrap();
}; };
if config.allow_jaeger { if config.allow_jaeger {

View file

@ -8,23 +8,15 @@ use ruma::{
use serde::{de::DeserializeOwned, Serialize}; use serde::{de::DeserializeOwned, Serialize};
use std::{collections::HashMap, sync::Arc}; use std::{collections::HashMap, sync::Arc};
use super::abstraction::Tree;
pub struct AccountData {
pub(super) roomuserdataid_accountdata: Arc<dyn Tree>, // RoomUserDataId = Room + User + Count + Type
pub(super) roomusertype_roomuserdataid: Arc<dyn Tree>, // RoomUserType = Room + User + Type
}
impl AccountData { impl AccountData {
/// Places one event in the account data of the user and removes the previous entry. /// Places one event in the account data of the user and removes the previous entry.
#[tracing::instrument(skip(self, room_id, user_id, event_type, data, globals))] #[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
pub fn update<T: Serialize>( pub fn update<T: Serialize>(
&self, &self,
room_id: Option<&RoomId>, room_id: Option<&RoomId>,
user_id: &UserId, user_id: &UserId,
event_type: RoomAccountDataEventType, event_type: RoomAccountDataEventType,
data: &T, data: &T,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut prefix = room_id let mut prefix = room_id
.map(|r| r.to_string()) .map(|r| r.to_string())
@ -36,7 +28,7 @@ impl AccountData {
prefix.push(0xff); prefix.push(0xff);
let mut roomuserdataid = prefix.clone(); let mut roomuserdataid = prefix.clone();
roomuserdataid.extend_from_slice(&globals.next_count()?.to_be_bytes()); roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
roomuserdataid.push(0xff); roomuserdataid.push(0xff);
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());

View file

@ -5,14 +5,6 @@ use std::{
time::Instant, time::Instant,
}; };
use crate::{
client_server::AUTO_GEN_PASSWORD_LENGTH,
error::{Error, Result},
pdu::PduBuilder,
server_server, utils,
utils::HtmlEscape,
Database, PduEvent,
};
use clap::Parser; use clap::Parser;
use regex::Regex; use regex::Regex;
use ruma::{ use ruma::{
@ -36,6 +28,10 @@ use ruma::{
use serde_json::value::to_raw_value; use serde_json::value::to_raw_value;
use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard}; use tokio::sync::{mpsc, MutexGuard, RwLock, RwLockReadGuard};
use crate::{services, Error, api::{server_server, client_server::AUTO_GEN_PASSWORD_LENGTH}, PduEvent, utils::{HtmlEscape, self}};
use super::pdu::PduBuilder;
#[derive(Debug)] #[derive(Debug)]
pub enum AdminRoomEvent { pub enum AdminRoomEvent {
ProcessMessage(String), ProcessMessage(String),
@ -50,22 +46,19 @@ pub struct Admin {
impl Admin { impl Admin {
pub fn start_handler( pub fn start_handler(
&self, &self,
db: Arc<RwLock<Database>>,
mut receiver: mpsc::UnboundedReceiver<AdminRoomEvent>, mut receiver: mpsc::UnboundedReceiver<AdminRoomEvent>,
) { ) {
tokio::spawn(async move { tokio::spawn(async move {
// TODO: Use futures when we have long admin commands // TODO: Use futures when we have long admin commands
//let mut futures = FuturesUnordered::new(); //let mut futures = FuturesUnordered::new();
let guard = db.read().await; let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name()))
let conduit_user = UserId::parse(format!("@conduit:{}", guard.globals.server_name()))
.expect("@conduit:server_name is valid"); .expect("@conduit:server_name is valid");
let conduit_room = guard let conduit_room = services()
.rooms .rooms
.id_from_alias( .id_from_alias(
format!("#admins:{}", guard.globals.server_name()) format!("#admins:{}", services().globals.server_name())
.as_str() .as_str()
.try_into() .try_into()
.expect("#admins:server_name is a valid room alias"), .expect("#admins:server_name is a valid room alias"),
@ -73,12 +66,9 @@ impl Admin {
.expect("Database data for admin room alias must be valid") .expect("Database data for admin room alias must be valid")
.expect("Admin room must exist"); .expect("Admin room must exist");
drop(guard);
let send_message = |message: RoomMessageEventContent, let send_message = |message: RoomMessageEventContent,
guard: RwLockReadGuard<'_, Database>,
mutex_lock: &MutexGuard<'_, ()>| { mutex_lock: &MutexGuard<'_, ()>| {
guard services()
.rooms .rooms
.build_and_append_pdu( .build_and_append_pdu(
PduBuilder { PduBuilder {
@ -91,7 +81,6 @@ impl Admin {
}, },
&conduit_user, &conduit_user,
&conduit_room, &conduit_room,
&guard,
mutex_lock, mutex_lock,
) )
.unwrap(); .unwrap();
@ -100,15 +89,13 @@ impl Admin {
loop { loop {
tokio::select! { tokio::select! {
Some(event) = receiver.recv() => { Some(event) = receiver.recv() => {
let guard = db.read().await;
let message_content = match event { let message_content = match event {
AdminRoomEvent::SendMessage(content) => content, AdminRoomEvent::SendMessage(content) => content,
AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(&*guard, room_message).await AdminRoomEvent::ProcessMessage(room_message) => process_admin_message(room_message).await
}; };
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
guard.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -118,7 +105,7 @@ impl Admin {
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
send_message(message_content, guard, &state_lock); send_message(message_content, &state_lock);
drop(state_lock); drop(state_lock);
} }
@ -141,7 +128,7 @@ impl Admin {
} }
// Parse and process a message from the admin room // Parse and process a message from the admin room
async fn process_admin_message(db: &Database, room_message: String) -> RoomMessageEventContent { async fn process_admin_message(room_message: String) -> RoomMessageEventContent {
let mut lines = room_message.lines(); let mut lines = room_message.lines();
let command_line = lines.next().expect("each string has at least one line"); let command_line = lines.next().expect("each string has at least one line");
let body: Vec<_> = lines.collect(); let body: Vec<_> = lines.collect();
@ -149,7 +136,7 @@ async fn process_admin_message(db: &Database, room_message: String) -> RoomMessa
let admin_command = match parse_admin_command(&command_line) { let admin_command = match parse_admin_command(&command_line) {
Ok(command) => command, Ok(command) => command,
Err(error) => { Err(error) => {
let server_name = db.globals.server_name(); let server_name = services().globals.server_name();
let message = error let message = error
.to_string() .to_string()
.replace("server.name", server_name.as_str()); .replace("server.name", server_name.as_str());
@ -159,7 +146,7 @@ async fn process_admin_message(db: &Database, room_message: String) -> RoomMessa
} }
}; };
match process_admin_command(db, admin_command, body).await { match process_admin_command(admin_command, body).await {
Ok(reply_message) => reply_message, Ok(reply_message) => reply_message,
Err(error) => { Err(error) => {
let markdown_message = format!( let markdown_message = format!(
@ -322,7 +309,6 @@ enum AdminCommand {
} }
async fn process_admin_command( async fn process_admin_command(
db: &Database,
command: AdminCommand, command: AdminCommand,
body: Vec<&str>, body: Vec<&str>,
) -> Result<RoomMessageEventContent> { ) -> Result<RoomMessageEventContent> {
@ -332,7 +318,7 @@ async fn process_admin_command(
let appservice_config = body[1..body.len() - 1].join("\n"); let appservice_config = body[1..body.len() - 1].join("\n");
let parsed_config = serde_yaml::from_str::<serde_yaml::Value>(&appservice_config); let parsed_config = serde_yaml::from_str::<serde_yaml::Value>(&appservice_config);
match parsed_config { match parsed_config {
Ok(yaml) => match db.appservice.register_appservice(yaml) { Ok(yaml) => match services().appservice.register_appservice(yaml) {
Ok(id) => RoomMessageEventContent::text_plain(format!( Ok(id) => RoomMessageEventContent::text_plain(format!(
"Appservice registered with ID: {}.", "Appservice registered with ID: {}.",
id id
@ -355,7 +341,7 @@ async fn process_admin_command(
} }
AdminCommand::UnregisterAppservice { AdminCommand::UnregisterAppservice {
appservice_identifier, appservice_identifier,
} => match db.appservice.unregister_appservice(&appservice_identifier) { } => match services().appservice.unregister_appservice(&appservice_identifier) {
Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."),
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
"Failed to unregister appservice: {}", "Failed to unregister appservice: {}",
@ -363,7 +349,7 @@ async fn process_admin_command(
)), )),
}, },
AdminCommand::ListAppservices => { AdminCommand::ListAppservices => {
if let Ok(appservices) = db.appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) { if let Ok(appservices) = services().appservice.iter_ids().map(|ids| ids.collect::<Vec<_>>()) {
let count = appservices.len(); let count = appservices.len();
let output = format!( let output = format!(
"Appservices ({}): {}", "Appservices ({}): {}",
@ -380,14 +366,14 @@ async fn process_admin_command(
} }
} }
AdminCommand::ListRooms => { AdminCommand::ListRooms => {
let room_ids = db.rooms.iter_ids(); let room_ids = services().rooms.iter_ids();
let output = format!( let output = format!(
"Rooms:\n{}", "Rooms:\n{}",
room_ids room_ids
.filter_map(|r| r.ok()) .filter_map(|r| r.ok())
.map(|id| id.to_string() .map(|id| id.to_string()
+ "\tMembers: " + "\tMembers: "
+ &db + &services()
.rooms .rooms
.room_joined_count(&id) .room_joined_count(&id)
.ok() .ok()
@ -399,7 +385,7 @@ async fn process_admin_command(
); );
RoomMessageEventContent::text_plain(output) RoomMessageEventContent::text_plain(output)
} }
AdminCommand::ListLocalUsers => match db.users.list_local_users() { AdminCommand::ListLocalUsers => match services().users.list_local_users() {
Ok(users) => { Ok(users) => {
let mut msg: String = format!("Found {} local user account(s):\n", users.len()); let mut msg: String = format!("Found {} local user account(s):\n", users.len());
msg += &users.join("\n"); msg += &users.join("\n");
@ -408,7 +394,7 @@ async fn process_admin_command(
Err(e) => RoomMessageEventContent::text_plain(e.to_string()), Err(e) => RoomMessageEventContent::text_plain(e.to_string()),
}, },
AdminCommand::IncomingFederation => { AdminCommand::IncomingFederation => {
let map = db.globals.roomid_federationhandletime.read().unwrap(); let map = services().globals.roomid_federationhandletime.read().unwrap();
let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); let mut msg: String = format!("Handling {} incoming pdus:\n", map.len());
for (r, (e, i)) in map.iter() { for (r, (e, i)) in map.iter() {
@ -425,7 +411,7 @@ async fn process_admin_command(
} }
AdminCommand::GetAuthChain { event_id } => { AdminCommand::GetAuthChain { event_id } => {
let event_id = Arc::<EventId>::from(event_id); let event_id = Arc::<EventId>::from(event_id);
if let Some(event) = db.rooms.get_pdu_json(&event_id)? { if let Some(event) = services().rooms.get_pdu_json(&event_id)? {
let room_id_str = event let room_id_str = event
.get("room_id") .get("room_id")
.and_then(|val| val.as_str()) .and_then(|val| val.as_str())
@ -435,7 +421,7 @@ async fn process_admin_command(
Error::bad_database("Invalid room id field in event in database") Error::bad_database("Invalid room id field in event in database")
})?; })?;
let start = Instant::now(); let start = Instant::now();
let count = server_server::get_auth_chain(room_id, vec![event_id], db) let count = server_server::get_auth_chain(room_id, vec![event_id])
.await? .await?
.count(); .count();
let elapsed = start.elapsed(); let elapsed = start.elapsed();
@ -486,10 +472,10 @@ async fn process_admin_command(
} }
AdminCommand::GetPdu { event_id } => { AdminCommand::GetPdu { event_id } => {
let mut outlier = false; let mut outlier = false;
let mut pdu_json = db.rooms.get_non_outlier_pdu_json(&event_id)?; let mut pdu_json = services().rooms.get_non_outlier_pdu_json(&event_id)?;
if pdu_json.is_none() { if pdu_json.is_none() {
outlier = true; outlier = true;
pdu_json = db.rooms.get_pdu_json(&event_id)?; pdu_json = services().rooms.get_pdu_json(&event_id)?;
} }
match pdu_json { match pdu_json {
Some(json) => { Some(json) => {
@ -519,7 +505,7 @@ async fn process_admin_command(
None => RoomMessageEventContent::text_plain("PDU not found."), None => RoomMessageEventContent::text_plain("PDU not found."),
} }
} }
AdminCommand::DatabaseMemoryUsage => match db._db.memory_usage() { AdminCommand::DatabaseMemoryUsage => match services()._db.memory_usage() {
Ok(response) => RoomMessageEventContent::text_plain(response), Ok(response) => RoomMessageEventContent::text_plain(response),
Err(e) => RoomMessageEventContent::text_plain(format!( Err(e) => RoomMessageEventContent::text_plain(format!(
"Failed to get database memory usage: {}", "Failed to get database memory usage: {}",
@ -528,12 +514,12 @@ async fn process_admin_command(
}, },
AdminCommand::ShowConfig => { AdminCommand::ShowConfig => {
// Construct and send the response // Construct and send the response
RoomMessageEventContent::text_plain(format!("{}", db.globals.config)) RoomMessageEventContent::text_plain(format!("{}", services().globals.config))
} }
AdminCommand::ResetPassword { username } => { AdminCommand::ResetPassword { username } => {
let user_id = match UserId::parse_with_server_name( let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(), username.as_str().to_lowercase(),
db.globals.server_name(), services().globals.server_name(),
) { ) {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
@ -545,10 +531,10 @@ async fn process_admin_command(
}; };
// Check if the specified user is valid // Check if the specified user is valid
if !db.users.exists(&user_id)? if !services().users.exists(&user_id)?
|| db.users.is_deactivated(&user_id)? || services().users.is_deactivated(&user_id)?
|| user_id || user_id
== UserId::parse_with_server_name("conduit", db.globals.server_name()) == UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("conduit user exists") .expect("conduit user exists")
{ {
return Ok(RoomMessageEventContent::text_plain( return Ok(RoomMessageEventContent::text_plain(
@ -558,7 +544,7 @@ async fn process_admin_command(
let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH);
match db.users.set_password(&user_id, Some(new_password.as_str())) { match services().users.set_password(&user_id, Some(new_password.as_str())) {
Ok(()) => RoomMessageEventContent::text_plain(format!( Ok(()) => RoomMessageEventContent::text_plain(format!(
"Successfully reset the password for user {}: {}", "Successfully reset the password for user {}: {}",
user_id, new_password user_id, new_password
@ -574,7 +560,7 @@ async fn process_admin_command(
// Validate user id // Validate user id
let user_id = match UserId::parse_with_server_name( let user_id = match UserId::parse_with_server_name(
username.as_str().to_lowercase(), username.as_str().to_lowercase(),
db.globals.server_name(), services().globals.server_name(),
) { ) {
Ok(id) => id, Ok(id) => id,
Err(e) => { Err(e) => {
@ -589,21 +575,21 @@ async fn process_admin_command(
"userid {user_id} is not allowed due to historical" "userid {user_id} is not allowed due to historical"
))); )));
} }
if db.users.exists(&user_id)? { if services().users.exists(&user_id)? {
return Ok(RoomMessageEventContent::text_plain(format!( return Ok(RoomMessageEventContent::text_plain(format!(
"userid {user_id} already exists" "userid {user_id} already exists"
))); )));
} }
// Create user // Create user
db.users.create(&user_id, Some(password.as_str()))?; services().users.create(&user_id, Some(password.as_str()))?;
// Default to pretty displayname // Default to pretty displayname
let displayname = format!("{} ⚡️", user_id.localpart()); let displayname = format!("{} ⚡️", user_id.localpart());
db.users services().users
.set_displayname(&user_id, Some(displayname.clone()))?; .set_displayname(&user_id, Some(displayname.clone()))?;
// Initial account data // Initial account data
db.account_data.update( services().account_data.update(
None, None,
&user_id, &user_id,
ruma::events::GlobalAccountDataEventType::PushRules ruma::events::GlobalAccountDataEventType::PushRules
@ -614,24 +600,21 @@ async fn process_admin_command(
global: ruma::push::Ruleset::server_default(&user_id), global: ruma::push::Ruleset::server_default(&user_id),
}, },
}, },
&db.globals,
)?; )?;
// we dont add a device since we're not the user, just the creator // we dont add a device since we're not the user, just the creator
db.flush()?;
// Inhibit login does not work for guests // Inhibit login does not work for guests
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"Created user with user_id: {user_id} and password: {password}" "Created user with user_id: {user_id} and password: {password}"
)) ))
} }
AdminCommand::DisableRoom { room_id } => { AdminCommand::DisableRoom { room_id } => {
db.rooms.disabledroomids.insert(room_id.as_bytes(), &[])?; services().rooms.disabledroomids.insert(room_id.as_bytes(), &[])?;
RoomMessageEventContent::text_plain("Room disabled.") RoomMessageEventContent::text_plain("Room disabled.")
} }
AdminCommand::EnableRoom { room_id } => { AdminCommand::EnableRoom { room_id } => {
db.rooms.disabledroomids.remove(room_id.as_bytes())?; services().rooms.disabledroomids.remove(room_id.as_bytes())?;
RoomMessageEventContent::text_plain("Room enabled.") RoomMessageEventContent::text_plain("Room enabled.")
} }
AdminCommand::DeactivateUser { AdminCommand::DeactivateUser {
@ -639,16 +622,16 @@ async fn process_admin_command(
user_id, user_id,
} => { } => {
let user_id = Arc::<UserId>::from(user_id); let user_id = Arc::<UserId>::from(user_id);
if db.users.exists(&user_id)? { if services().users.exists(&user_id)? {
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
"Making {} leave all rooms before deactivation...", "Making {} leave all rooms before deactivation...",
user_id user_id
)); ));
db.users.deactivate_account(&user_id)?; services().users.deactivate_account(&user_id)?;
if leave_rooms { if leave_rooms {
db.rooms.leave_all_rooms(&user_id, &db).await?; services().rooms.leave_all_rooms(&user_id).await?;
} }
RoomMessageEventContent::text_plain(format!( RoomMessageEventContent::text_plain(format!(
@ -685,7 +668,7 @@ async fn process_admin_command(
if !force { if !force {
user_ids.retain(|&user_id| { user_ids.retain(|&user_id| {
match db.users.is_admin(user_id, &db.rooms, &db.globals) { match services().users.is_admin(user_id) {
Ok(is_admin) => match is_admin { Ok(is_admin) => match is_admin {
true => { true => {
admins.push(user_id.localpart()); admins.push(user_id.localpart());
@ -699,7 +682,7 @@ async fn process_admin_command(
} }
for &user_id in &user_ids { for &user_id in &user_ids {
match db.users.deactivate_account(user_id) { match services().users.deactivate_account(user_id) {
Ok(_) => deactivation_count += 1, Ok(_) => deactivation_count += 1,
Err(_) => {} Err(_) => {}
} }
@ -707,7 +690,7 @@ async fn process_admin_command(
if leave_rooms { if leave_rooms {
for &user_id in &user_ids { for &user_id in &user_ids {
let _ = db.rooms.leave_all_rooms(user_id, &db).await; let _ = services().rooms.leave_all_rooms(user_id).await;
} }
} }
@ -814,13 +797,13 @@ fn usage_to_html(text: &str, server_name: &ServerName) -> String {
/// ///
/// Users in this room are considered admins by conduit, and the room can be /// Users in this room are considered admins by conduit, and the room can be
/// used to issue admin commands by talking to the server user inside it. /// used to issue admin commands by talking to the server user inside it.
pub(crate) async fn create_admin_room(db: &Database) -> Result<()> { pub(crate) async fn create_admin_room() -> Result<()> {
let room_id = RoomId::new(db.globals.server_name()); let room_id = RoomId::new(services().globals.server_name());
db.rooms.get_or_create_shortroomid(&room_id, &db.globals)?; services().rooms.get_or_create_shortroomid(&room_id)?;
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -830,10 +813,10 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Create a user for the server // Create a user for the server
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is valid"); .expect("@conduit:server_name is valid");
db.users.create(&conduit_user, None)?; services().users.create(&conduit_user, None)?;
let mut content = RoomCreateEventContent::new(conduit_user.clone()); let mut content = RoomCreateEventContent::new(conduit_user.clone());
content.federate = true; content.federate = true;
@ -841,7 +824,7 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
content.room_version = RoomVersionId::V6; content.room_version = RoomVersionId::V6;
// 1. The room create event // 1. The room create event
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomCreate, event_type: RoomEventType::RoomCreate,
content: to_raw_value(&content).expect("event is valid, we just created it"), content: to_raw_value(&content).expect("event is valid, we just created it"),
@ -851,12 +834,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 2. Make conduit bot join // 2. Make conduit bot join
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMember, event_type: RoomEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
@ -876,7 +858,6 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
@ -884,7 +865,7 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
let mut users = BTreeMap::new(); let mut users = BTreeMap::new();
users.insert(conduit_user.clone(), 100.into()); users.insert(conduit_user.clone(), 100.into());
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomPowerLevels, event_type: RoomEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent { content: to_raw_value(&RoomPowerLevelsEventContent {
@ -898,12 +879,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 4.1 Join Rules // 4.1 Join Rules
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomJoinRules, event_type: RoomEventType::RoomJoinRules,
content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite))
@ -914,12 +894,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 4.2 History Visibility // 4.2 History Visibility
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomHistoryVisibility, event_type: RoomEventType::RoomHistoryVisibility,
content: to_raw_value(&RoomHistoryVisibilityEventContent::new( content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
@ -932,12 +911,11 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 4.3 Guest Access // 4.3 Guest Access
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomGuestAccess, event_type: RoomEventType::RoomGuestAccess,
content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden))
@ -948,14 +926,13 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 5. Events implied by name and topic // 5. Events implied by name and topic
let room_name = RoomName::parse(format!("{} Admin Room", db.globals.server_name())) let room_name = RoomName::parse(format!("{} Admin Room", services().globals.server_name()))
.expect("Room name is valid"); .expect("Room name is valid");
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomName, event_type: RoomEventType::RoomName,
content: to_raw_value(&RoomNameEventContent::new(Some(room_name))) content: to_raw_value(&RoomNameEventContent::new(Some(room_name)))
@ -966,15 +943,14 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomTopic, event_type: RoomEventType::RoomTopic,
content: to_raw_value(&RoomTopicEventContent { content: to_raw_value(&RoomTopicEventContent {
topic: format!("Manage {}", db.globals.server_name()), topic: format!("Manage {}", services().globals.server_name()),
}) })
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -983,16 +959,15 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// 6. Room alias // 6. Room alias
let alias: Box<RoomAliasId> = format!("#admins:{}", db.globals.server_name()) let alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into() .try_into()
.expect("#admins:server_name is a valid alias name"); .expect("#admins:server_name is a valid alias name");
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomCanonicalAlias, event_type: RoomEventType::RoomCanonicalAlias,
content: to_raw_value(&RoomCanonicalAliasEventContent { content: to_raw_value(&RoomCanonicalAliasEventContent {
@ -1006,11 +981,10 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
db.rooms.set_alias(&alias, Some(&room_id), &db.globals)?; services().rooms.set_alias(&alias, Some(&room_id))?;
Ok(()) Ok(())
} }
@ -1019,20 +993,19 @@ pub(crate) async fn create_admin_room(db: &Database) -> Result<()> {
/// ///
/// In conduit, this is equivalent to granting admin privileges. /// In conduit, this is equivalent to granting admin privileges.
pub(crate) async fn make_user_admin( pub(crate) async fn make_user_admin(
db: &Database,
user_id: &UserId, user_id: &UserId,
displayname: String, displayname: String,
) -> Result<()> { ) -> Result<()> {
let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", db.globals.server_name()) let admin_room_alias: Box<RoomAliasId> = format!("#admins:{}", services().globals.server_name())
.try_into() .try_into()
.expect("#admins:server_name is a valid alias name"); .expect("#admins:server_name is a valid alias name");
let room_id = db let room_id = services()
.rooms .rooms
.id_from_alias(&admin_room_alias)? .id_from_alias(&admin_room_alias)?
.expect("Admin room must exist"); .expect("Admin room must exist");
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services().globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -1042,11 +1015,11 @@ pub(crate) async fn make_user_admin(
let state_lock = mutex_state.lock().await; let state_lock = mutex_state.lock().await;
// Use the server user to grant the new admin's power level // Use the server user to grant the new admin's power level
let conduit_user = UserId::parse_with_server_name("conduit", db.globals.server_name()) let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
.expect("@conduit:server_name is valid"); .expect("@conduit:server_name is valid");
// Invite and join the real user // Invite and join the real user
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMember, event_type: RoomEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
@ -1066,10 +1039,9 @@ pub(crate) async fn make_user_admin(
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMember, event_type: RoomEventType::RoomMember,
content: to_raw_value(&RoomMemberEventContent { content: to_raw_value(&RoomMemberEventContent {
@ -1089,7 +1061,6 @@ pub(crate) async fn make_user_admin(
}, },
&user_id, &user_id,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
@ -1098,7 +1069,7 @@ pub(crate) async fn make_user_admin(
users.insert(conduit_user.to_owned(), 100.into()); users.insert(conduit_user.to_owned(), 100.into());
users.insert(user_id.to_owned(), 100.into()); users.insert(user_id.to_owned(), 100.into());
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomPowerLevels, event_type: RoomEventType::RoomPowerLevels,
content: to_raw_value(&RoomPowerLevelsEventContent { content: to_raw_value(&RoomPowerLevelsEventContent {
@ -1112,17 +1083,16 @@ pub(crate) async fn make_user_admin(
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;
// Send welcome message // Send welcome message
db.rooms.build_and_append_pdu( services().rooms.build_and_append_pdu(
PduBuilder { PduBuilder {
event_type: RoomEventType::RoomMessage, event_type: RoomEventType::RoomMessage,
content: to_raw_value(&RoomMessageEventContent::text_html( content: to_raw_value(&RoomMessageEventContent::text_html(
format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", db.globals.server_name()).to_owned(), format!("## Thank you for trying out Conduit!\n\nConduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.\n\nHelpful links:\n> Website: https://conduit.rs\n> Git and Documentation: https://gitlab.com/famedly/conduit\n> Report issues: https://gitlab.com/famedly/conduit/-/issues\n\nFor a list of available commands, send the following message in this room: `@conduit:{}: --help`\n\nHere are some rooms you can join (by typing the command):\n\nConduit room (Ask questions and get notified on updates):\n`/join #conduit:fachschaften.org`\n\nConduit lounge (Off-topic, only Conduit users are allowed to join)\n`/join #conduit-lounge:conduit.rs`", services().globals.server_name()).to_owned(),
format!("<h2>Thank you for trying out Conduit!</h2>\n<p>Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Website: https://conduit.rs<br>Git and Documentation: https://gitlab.com/famedly/conduit<br>Report issues: https://gitlab.com/famedly/conduit/-/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>Conduit room (Ask questions and get notified on updates):<br><code>/join #conduit:fachschaften.org</code></p>\n<p>Conduit lounge (Off-topic, only Conduit users are allowed to join)<br><code>/join #conduit-lounge:conduit.rs</code></p>\n", db.globals.server_name()).to_owned(), format!("<h2>Thank you for trying out Conduit!</h2>\n<p>Conduit is currently in Beta. This means you can join and participate in most Matrix rooms, but not all features are supported and you might run into bugs from time to time.</p>\n<p>Helpful links:</p>\n<blockquote>\n<p>Website: https://conduit.rs<br>Git and Documentation: https://gitlab.com/famedly/conduit<br>Report issues: https://gitlab.com/famedly/conduit/-/issues</p>\n</blockquote>\n<p>For a list of available commands, send the following message in this room: <code>@conduit:{}: --help</code></p>\n<p>Here are some rooms you can join (by typing the command):</p>\n<p>Conduit room (Ask questions and get notified on updates):<br><code>/join #conduit:fachschaften.org</code></p>\n<p>Conduit lounge (Off-topic, only Conduit users are allowed to join)<br><code>/join #conduit-lounge:conduit.rs</code></p>\n", services().globals.server_name()).to_owned(),
)) ))
.expect("event is valid, we just created it"), .expect("event is valid, we just created it"),
unsigned: None, unsigned: None,
@ -1131,7 +1101,6 @@ pub(crate) async fn make_user_admin(
}, },
&conduit_user, &conduit_user,
&room_id, &room_id,
&db,
&state_lock, &state_lock,
)?; )?;

View file

@ -1,17 +1,18 @@
pub trait Data { pub trait Data {
type Iter: Iterator;
/// Registers an appservice and returns the ID to the caller /// Registers an appservice and returns the ID to the caller
pub fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String>; fn register_appservice(&self, yaml: serde_yaml::Value) -> Result<String>;
/// Remove an appservice registration /// Remove an appservice registration
/// ///
/// # Arguments /// # Arguments
/// ///
/// * `service_name` - the name you send to register the service previously /// * `service_name` - the name you send to register the service previously
pub fn unregister_appservice(&self, service_name: &str) -> Result<()>; fn unregister_appservice(&self, service_name: &str) -> Result<()>;
pub fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>>; fn get_registration(&self, id: &str) -> Result<Option<serde_yaml::Value>>;
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_>; fn iter_ids(&self) -> Result<Self::Iter<Item = Result<String>>>;
pub fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>>; fn all(&self) -> Result<Vec<(String, serde_yaml::Value)>>;
} }

View file

@ -1,4 +1,4 @@
use crate::{utils, Error, Result}; use crate::{utils, Error, Result, services};
use ruma::{ use ruma::{
api::client::{ api::client::{
backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
@ -9,22 +9,13 @@ use ruma::{
}; };
use std::{collections::BTreeMap, sync::Arc}; use std::{collections::BTreeMap, sync::Arc};
use super::abstraction::Tree;
pub struct KeyBackups {
pub(super) backupid_algorithm: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
pub(super) backupid_etag: Arc<dyn Tree>, // BackupId = UserId + Version(Count)
pub(super) backupkeyid_backup: Arc<dyn Tree>, // BackupKeyId = UserId + Version + RoomId + SessionId
}
impl KeyBackups { impl KeyBackups {
pub fn create_backup( pub fn create_backup(
&self, &self,
user_id: &UserId, user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
globals: &super::globals::Globals,
) -> Result<String> { ) -> Result<String> {
let version = globals.next_count()?.to_string(); let version = services().globals.next_count()?.to_string();
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -35,7 +26,7 @@ impl KeyBackups {
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?; )?;
self.backupid_etag self.backupid_etag
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version) Ok(version)
} }
@ -61,7 +52,6 @@ impl KeyBackups {
user_id: &UserId, user_id: &UserId,
version: &str, version: &str,
backup_metadata: &Raw<BackupAlgorithm>, backup_metadata: &Raw<BackupAlgorithm>,
globals: &super::globals::Globals,
) -> Result<String> { ) -> Result<String> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -77,7 +67,7 @@ impl KeyBackups {
self.backupid_algorithm self.backupid_algorithm
.insert(&key, backup_metadata.json().get().as_bytes())?; .insert(&key, backup_metadata.json().get().as_bytes())?;
self.backupid_etag self.backupid_etag
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version.to_owned()) Ok(version.to_owned())
} }
@ -157,7 +147,6 @@ impl KeyBackups {
room_id: &RoomId, room_id: &RoomId,
session_id: &str, session_id: &str,
key_data: &Raw<KeyBackupData>, key_data: &Raw<KeyBackupData>,
globals: &super::globals::Globals,
) -> Result<()> { ) -> Result<()> {
let mut key = user_id.as_bytes().to_vec(); let mut key = user_id.as_bytes().to_vec();
key.push(0xff); key.push(0xff);
@ -171,7 +160,7 @@ impl KeyBackups {
} }
self.backupid_etag self.backupid_etag
.insert(&key, &globals.next_count()?.to_be_bytes())?; .insert(&key, &services().globals.next_count()?.to_be_bytes())?;
key.push(0xff); key.push(0xff);
key.extend_from_slice(room_id.as_bytes()); key.extend_from_slice(room_id.as_bytes());

View file

@ -1,4 +1,3 @@
use crate::database::globals::Globals;
use image::{imageops::FilterType, GenericImageView}; use image::{imageops::FilterType, GenericImageView};
use super::abstraction::Tree; use super::abstraction::Tree;

28
src/service/mod.rs Normal file
View file

@ -0,0 +1,28 @@
pub mod pdu;
pub mod appservice;
pub mod pusher;
pub mod rooms;
pub mod transaction_ids;
pub mod uiaa;
pub mod users;
pub mod account_data;
pub mod admin;
pub mod globals;
pub mod key_backups;
pub mod media;
pub mod sending;
pub struct Services<D> {
pub appservice: appservice::Service<D>,
pub pusher: pusher::Service<D>,
pub rooms: rooms::Service<D>,
pub transaction_ids: transaction_ids::Service<D>,
pub uiaa: uiaa::Service<D>,
pub users: users::Service<D>,
//pub account_data: account_data::Service<D>,
//pub admin: admin::Service<D>,
pub globals: globals::Service<D>,
//pub key_backups: key_backups::Service<D>,
//pub media: media::Service<D>,
//pub sending: sending::Service<D>,
}

View file

@ -1,4 +1,4 @@
use crate::{Database, Error}; use crate::{Database, Error, services};
use ruma::{ use ruma::{
events::{ events::{
room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent, room::member::RoomMemberEventContent, AnyEphemeralRoomEvent, AnyRoomEvent, AnyStateEvent,
@ -332,7 +332,6 @@ impl Ord for PduEvent {
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`. /// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json( pub(crate) fn gen_event_id_canonical_json(
pdu: &RawJsonValue, pdu: &RawJsonValue,
db: &Database,
) -> crate::Result<(Box<EventId>, CanonicalJsonObject)> { ) -> crate::Result<(Box<EventId>, CanonicalJsonObject)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
warn!("Error parsing incoming event {:?}: {:?}", pdu, e); warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
@ -344,7 +343,7 @@ pub(crate) fn gen_event_id_canonical_json(
.and_then(|id| RoomId::parse(id.as_str()?).ok()) .and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?; .ok_or_else(|| Error::bad_database("PDU in db has invalid room_id."))?;
let room_version_id = db.rooms.get_room_version(&room_id); let room_version_id = services().rooms.get_room_version(&room_id);
let event_id = format!( let event_id = format!(
"${}", "${}",

View file

@ -1,11 +1,13 @@
use ruma::{UserId, api::client::push::{set_pusher, get_pushers}};
pub trait Data { pub trait Data {
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>; fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::Pusher) -> Result<()>;
pub fn get_pusher(&self, senderkey: &[u8]) -> Result<Option<get_pushers::v3::Pusher>>; fn get_pusher(&self, senderkey: &[u8]) -> Result<Option<get_pushers::v3::Pusher>>;
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>>; fn get_pushers(&self, sender: &UserId) -> Result<Vec<get_pushers::v3::Pusher>>;
pub fn get_pusher_senderkeys<'a>( fn get_pusher_senderkeys<'a>(
&'a self, &'a self,
sender: &UserId, sender: &UserId,
) -> impl Iterator<Item = Vec<u8>> + 'a; ) -> impl Iterator<Item = Vec<u8>> + 'a;

View file

@ -1,7 +1,27 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use crate::service::*; use crate::{services, Error, PduEvent};
use bytes::BytesMut;
use ruma::{
api::{
client::push::{get_pushers, set_pusher, PusherKind},
push_gateway::send_event_notification::{
self,
v1::{Device, Notification, NotificationCounts, NotificationPriority},
},
MatrixVersion, OutgoingRequest, SendAccessToken,
},
events::{
room::{name::RoomNameEventContent, power_levels::RoomPowerLevelsEventContent},
AnySyncRoomEvent, RoomEventType, StateEventType,
},
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
serde::Raw,
uint, RoomId, UInt, UserId,
};
use std::{fmt::Debug, mem};
use tracing::{error, info, warn};
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,
@ -27,9 +47,8 @@ impl Service<_> {
self.db.get_pusher_senderkeys(sender) self.db.get_pusher_senderkeys(sender)
} }
#[tracing::instrument(skip(globals, destination, request))] #[tracing::instrument(skip(destination, request))]
pub async fn send_request<T: OutgoingRequest>( pub async fn send_request<T: OutgoingRequest>(
globals: &crate::database::globals::Globals,
destination: &str, destination: &str,
request: T, request: T,
) -> Result<T::IncomingResponse> ) -> Result<T::IncomingResponse>
@ -57,7 +76,7 @@ impl Service<_> {
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
let url = reqwest_request.url().clone(); let url = reqwest_request.url().clone();
let response = globals.default_client().execute(reqwest_request).await; let response = services().globals.default_client().execute(reqwest_request).await;
match response { match response {
Ok(mut response) => { Ok(mut response) => {
@ -105,19 +124,19 @@ impl Service<_> {
} }
} }
#[tracing::instrument(skip(user, unread, pusher, ruleset, pdu, db))] #[tracing::instrument(skip(user, unread, pusher, ruleset, pdu))]
pub async fn send_push_notice( pub async fn send_push_notice(
&self,
user: &UserId, user: &UserId,
unread: UInt, unread: UInt,
pusher: &get_pushers::v3::Pusher, pusher: &get_pushers::v3::Pusher,
ruleset: Ruleset, ruleset: Ruleset,
pdu: &PduEvent, pdu: &PduEvent,
db: &Database,
) -> Result<()> { ) -> Result<()> {
let mut notify = None; let mut notify = None;
let mut tweaks = Vec::new(); let mut tweaks = Vec::new();
let power_levels: RoomPowerLevelsEventContent = db let power_levels: RoomPowerLevelsEventContent = services()
.rooms .rooms
.room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")?
.map(|ev| { .map(|ev| {
@ -127,13 +146,12 @@ impl Service<_> {
.transpose()? .transpose()?
.unwrap_or_default(); .unwrap_or_default();
for action in get_actions( for action in self.get_actions(
user, user,
&ruleset, &ruleset,
&power_levels, &power_levels,
&pdu.to_sync_room_event(), &pdu.to_sync_room_event(),
&pdu.room_id, &pdu.room_id,
db,
)? { )? {
let n = match action { let n = match action {
Action::DontNotify => false, Action::DontNotify => false,
@ -155,27 +173,26 @@ impl Service<_> {
} }
if notify == Some(true) { if notify == Some(true) {
send_notice(unread, pusher, tweaks, pdu, db).await?; self.send_notice(unread, pusher, tweaks, pdu).await?;
} }
// Else the event triggered no actions // Else the event triggered no actions
Ok(()) Ok(())
} }
#[tracing::instrument(skip(user, ruleset, pdu, db))] #[tracing::instrument(skip(user, ruleset, pdu))]
pub fn get_actions<'a>( pub fn get_actions<'a>(
&self,
user: &UserId, user: &UserId,
ruleset: &'a Ruleset, ruleset: &'a Ruleset,
power_levels: &RoomPowerLevelsEventContent, power_levels: &RoomPowerLevelsEventContent,
pdu: &Raw<AnySyncRoomEvent>, pdu: &Raw<AnySyncRoomEvent>,
room_id: &RoomId, room_id: &RoomId,
db: &Database,
) -> Result<&'a [Action]> { ) -> Result<&'a [Action]> {
let ctx = PushConditionRoomCtx { let ctx = PushConditionRoomCtx {
room_id: room_id.to_owned(), room_id: room_id.to_owned(),
member_count: 10_u32.into(), // TODO: get member count efficiently member_count: 10_u32.into(), // TODO: get member count efficiently
user_display_name: db user_display_name: services().users
.users
.displayname(user)? .displayname(user)?
.unwrap_or_else(|| user.localpart().to_owned()), .unwrap_or_else(|| user.localpart().to_owned()),
users_power_levels: power_levels.users.clone(), users_power_levels: power_levels.users.clone(),
@ -186,13 +203,13 @@ impl Service<_> {
Ok(ruleset.get_actions(pdu, &ctx)) Ok(ruleset.get_actions(pdu, &ctx))
} }
#[tracing::instrument(skip(unread, pusher, tweaks, event, db))] #[tracing::instrument(skip(unread, pusher, tweaks, event))]
async fn send_notice( async fn send_notice(
&self,
unread: UInt, unread: UInt,
pusher: &get_pushers::v3::Pusher, pusher: &get_pushers::v3::Pusher,
tweaks: Vec<Tweak>, tweaks: Vec<Tweak>,
event: &PduEvent, event: &PduEvent,
db: &Database,
) -> Result<()> { ) -> Result<()> {
// TODO: email // TODO: email
if pusher.kind == PusherKind::Email { if pusher.kind == PusherKind::Email {
@ -240,12 +257,8 @@ impl Service<_> {
} }
if event_id_only { if event_id_only {
send_request( self.send_request(url, send_event_notification::v1::Request::new(notifi))
&db.globals, .await?;
url,
send_event_notification::v1::Request::new(notifi),
)
.await?;
} else { } else {
notifi.sender = Some(&event.sender); notifi.sender = Some(&event.sender);
notifi.event_type = Some(&event.kind); notifi.event_type = Some(&event.kind);
@ -256,11 +269,11 @@ impl Service<_> {
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
} }
let user_name = db.users.displayname(&event.sender)?; let user_name = services().users.displayname(&event.sender)?;
notifi.sender_display_name = user_name.as_deref(); notifi.sender_display_name = user_name.as_deref();
let room_name = if let Some(room_name_pdu) = let room_name = if let Some(room_name_pdu) =
db.rooms services().rooms
.room_state_get(&event.room_id, &StateEventType::RoomName, "")? .room_state_get(&event.room_id, &StateEventType::RoomName, "")?
{ {
serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get()) serde_json::from_str::<RoomNameEventContent>(room_name_pdu.content.get())
@ -272,8 +285,7 @@ impl Service<_> {
notifi.room_name = room_name.as_deref(); notifi.room_name = room_name.as_deref();
send_request( self.send_request(
&db.globals,
url, url,
send_event_notification::v1::Request::new(notifi), send_event_notification::v1::Request::new(notifi),
) )

View file

@ -1,22 +1,24 @@
use ruma::{RoomId, RoomAliasId};
pub trait Data { pub trait Data {
/// Creates or updates the alias to the given room id. /// Creates or updates the alias to the given room id.
pub fn set_alias( fn set_alias(
alias: &RoomAliasId, alias: &RoomAliasId,
room_id: &RoomId room_id: &RoomId
) -> Result<()>; ) -> Result<()>;
/// Forgets about an alias. Returns an error if the alias did not exist. /// Forgets about an alias. Returns an error if the alias did not exist.
pub fn remove_alias( fn remove_alias(
alias: &RoomAliasId, alias: &RoomAliasId,
) -> Result<()>; ) -> Result<()>;
/// Looks up the roomid for the given alias. /// Looks up the roomid for the given alias.
pub fn resolve_local_alias( fn resolve_local_alias(
alias: &RoomAliasId, alias: &RoomAliasId,
) -> Result<()>; ) -> Result<()>;
/// Returns all local aliases that point to the given room /// Returns all local aliases that point to the given room
pub fn local_aliases_for_room( fn local_aliases_for_room(
alias: &RoomAliasId, alias: &RoomAliasId,
) -> Result<()>; ) -> Result<()>;
} }

View file

@ -1,14 +1,13 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::{RoomAliasId, RoomId};
use crate::service::*;
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,
} }
impl Service<_> { impl Service<_> {
#[tracing::instrument(skip(self, globals))] #[tracing::instrument(skip(self))]
pub fn set_alias( pub fn set_alias(
&self, &self,
alias: &RoomAliasId, alias: &RoomAliasId,
@ -17,7 +16,7 @@ impl Service<_> {
self.db.set_alias(alias, room_id) self.db.set_alias(alias, room_id)
} }
#[tracing::instrument(skip(self, globals))] #[tracing::instrument(skip(self))]
pub fn remove_alias( pub fn remove_alias(
&self, &self,
alias: &RoomAliasId, alias: &RoomAliasId,

View file

@ -1,3 +1,5 @@
use std::collections::HashSet;
pub trait Data { pub trait Data {
fn get_cached_eventid_authchain<'a>() -> Result<HashSet<u64>>; fn get_cached_eventid_authchain<'a>() -> Result<HashSet<u64>>;
fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet<u64>) -> Result<HashSet<u64>>; fn cache_eventid_authchain<'a>(shorteventid: u64, auth_chain: &HashSet<u64>) -> Result<HashSet<u64>>;

View file

@ -1,4 +1,6 @@
mod data; mod data;
use std::{sync::Arc, collections::HashSet};
pub use data::Data; pub use data::Data;
use crate::service::*; use crate::service::*;

View file

@ -1,3 +1,5 @@
use ruma::RoomId;
pub trait Data { pub trait Data {
/// Adds the room to the public room directory /// Adds the room to the public room directory
fn set_public(room_id: &RoomId) -> Result<()>; fn set_public(room_id: &RoomId) -> Result<()>;

View file

@ -1,5 +1,6 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::RoomId;
use crate::service::*; use crate::service::*;
@ -10,21 +11,21 @@ pub struct Service<D: Data> {
impl Service<_> { impl Service<_> {
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { pub fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.db.set_public(&self, room_id) self.db.set_public(room_id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
self.db.set_not_public(&self, room_id) self.db.set_not_public(room_id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
self.db.is_public_room(&self, room_id) self.db.is_public_room(room_id)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn public_rooms(&self) -> impl Iterator<Item = Result<Box<RoomId>>> + '_ { pub fn public_rooms(&self) -> impl Iterator<Item = Result<Box<RoomId>>> + '_ {
self.db.public_rooms(&self, room_id) self.db.public_rooms()
} }
} }

View file

@ -1,3 +1,9 @@
pub mod presence; pub mod presence;
pub mod read_receipt; pub mod read_receipt;
pub mod typing; pub mod typing;
pub struct Service<D> {
presence: presence::Service<D>,
read_receipt: read_receipt::Service<D>,
typing: typing::Service<D>,
}

View file

@ -1,3 +1,7 @@
use std::collections::HashMap;
use ruma::{UserId, RoomId, events::presence::PresenceEvent};
pub trait Data { pub trait Data {
/// Adds a presence event which will be saved until a new event replaces it. /// Adds a presence event which will be saved until a new event replaces it.
/// ///

View file

@ -1,5 +1,8 @@
mod data; mod data;
use std::collections::HashMap;
pub use data::Data; pub use data::Data;
use ruma::{RoomId, UserId, events::presence::PresenceEvent};
use crate::service::*; use crate::service::*;
@ -108,7 +111,7 @@ impl Service<_> {
}*/ }*/
/// Returns the most recent presence updates that happened after the event with id `since`. /// Returns the most recent presence updates that happened after the event with id `since`.
#[tracing::instrument(skip(self, since, _rooms, _globals))] #[tracing::instrument(skip(self, since, room_id))]
pub fn presence_since( pub fn presence_since(
&self, &self,
room_id: &RoomId, room_id: &RoomId,

View file

@ -1,3 +1,5 @@
use ruma::{RoomId, events::receipt::ReceiptEvent, UserId, serde::Raw};
pub trait Data { pub trait Data {
/// Replaces the previous read receipt. /// Replaces the previous read receipt.
fn readreceipt_update( fn readreceipt_update(

View file

@ -1,7 +1,6 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::{RoomId, UserId, events::receipt::ReceiptEvent, serde::Raw};
use crate::service::*;
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,
@ -15,7 +14,7 @@ impl Service<_> {
room_id: &RoomId, room_id: &RoomId,
event: ReceiptEvent, event: ReceiptEvent,
) -> Result<()> { ) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event); self.db.readreceipt_update(user_id, room_id, event)
} }
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
@ -35,7 +34,7 @@ impl Service<_> {
} }
/// Sets a private read marker at `count`. /// Sets a private read marker at `count`.
#[tracing::instrument(skip(self, globals))] #[tracing::instrument(skip(self))]
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
self.db.private_read_set(room_id, user_id, count) self.db.private_read_set(room_id, user_id, count)
} }

View file

@ -1,3 +1,7 @@
use std::collections::HashSet;
use ruma::{UserId, RoomId};
pub trait Data { pub trait Data {
/// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is
/// called. /// called.

View file

@ -1,5 +1,6 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::{UserId, RoomId};
use crate::service::*; use crate::service::*;
@ -66,7 +67,6 @@ impl Service<_> {
*/ */
/// Returns the count of the last typing update in this room. /// Returns the count of the last typing update in this room.
#[tracing::instrument(skip(self, globals))]
pub fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> { pub fn last_typing_update(&self, room_id: &RoomId) -> Result<u64> {
self.db.last_typing_update(room_id) self.db.last_typing_update(room_id)
} }

View file

@ -1,8 +1,29 @@
/// An async function that can recursively call itself. /// An async function that can recursively call itself.
type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>; type AsyncRecursiveType<'a, T> = Pin<Box<dyn Future<Output = T> + 'a + Send>>;
use crate::service::*; use std::{
collections::{btree_map, hash_map, BTreeMap, HashMap, HashSet},
pin::Pin,
sync::{Arc, RwLock},
time::{Duration, Instant},
};
use futures_util::Future;
use ruma::{
api::{
client::error::ErrorKind,
federation::event::{get_event, get_room_state_ids},
},
events::{room::create::RoomCreateEventContent, StateEventType},
int,
serde::Base64,
signatures::CanonicalJsonValue,
state_res::{self, RoomVersion, StateMap},
uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName,
};
use tracing::{error, info, trace, warn};
use crate::{service::*, services, Error, PduEvent};
pub struct Service; pub struct Service;
@ -31,45 +52,47 @@ impl Service {
/// it /// it
/// 14. Use state resolution to find new room state /// 14. Use state resolution to find new room state
// We use some AsyncRecursiveType hacks here so we can call this async funtion recursively // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively
#[tracing::instrument(skip(value, is_timeline_event, db, pub_key_map))] #[tracing::instrument(skip(value, is_timeline_event, pub_key_map))]
pub(crate) async fn handle_incoming_pdu<'a>( pub(crate) async fn handle_incoming_pdu<'a>(
&self,
origin: &'a ServerName, origin: &'a ServerName,
event_id: &'a EventId, event_id: &'a EventId,
room_id: &'a RoomId, room_id: &'a RoomId,
value: BTreeMap<String, CanonicalJsonValue>, value: BTreeMap<String, CanonicalJsonValue>,
is_timeline_event: bool, is_timeline_event: bool,
db: &'a Database,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>> { ) -> Result<Option<Vec<u8>>> {
db.rooms.exists(room_id)?.ok_or(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server"))?; services().rooms.exists(room_id)?.ok_or(Error::BadRequest(
ErrorKind::NotFound,
"Room is unknown to this server",
))?;
services()
.rooms
.is_disabled(room_id)?
.ok_or(Error::BadRequest(
ErrorKind::Forbidden,
"Federation of this room is currently disabled on this server.",
))?;
db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of this room is currently disabled on this server."))?;
// 1. Skip the PDU if we already have it as a timeline event // 1. Skip the PDU if we already have it as a timeline event
if let Some(pdu_id) = db.rooms.get_pdu_id(event_id)? { if let Some(pdu_id) = services().rooms.get_pdu_id(event_id)? {
return Some(pdu_id.to_vec()); return Ok(Some(pdu_id.to_vec()));
} }
let create_event = db let create_event = services()
.rooms .rooms
.room_state_get(room_id, &StateEventType::RoomCreate, "")? .room_state_get(room_id, &StateEventType::RoomCreate, "")?
.ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?;
let first_pdu_in_room = db let first_pdu_in_room = services()
.rooms .rooms
.first_pdu_in_room(room_id)? .first_pdu_in_room(room_id)?
.ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?;
let (incoming_pdu, val) = handle_outlier_pdu( let (incoming_pdu, val) = self
origin, .handle_outlier_pdu(origin, &create_event, event_id, room_id, value, pub_key_map)
&create_event, .await?;
event_id,
room_id,
value,
db,
pub_key_map,
)
.await?;
// 8. if not timeline event: stop // 8. if not timeline event: stop
if !is_timeline_event { if !is_timeline_event {
@ -82,15 +105,27 @@ impl Service {
} }
// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events
let sorted_prev_events = fetch_unknown_prev_events(incoming_pdu.prev_events.clone()); let (sorted_prev_events, eventid_info) = self.fetch_unknown_prev_events(
origin,
&create_event,
room_id,
pub_key_map,
incoming_pdu.prev_events.clone(),
);
let mut errors = 0; let mut errors = 0;
for prev_id in dbg!(sorted) { for prev_id in dbg!(sorted_prev_events) {
// Check for disabled again because it might have changed // Check for disabled again because it might have changed
db.rooms.is_disabled(room_id)?.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Federation of services()
this room is currently disabled on this server."))?; .rooms
.is_disabled(room_id)?
.ok_or(Error::BadRequest(
ErrorKind::Forbidden,
"Federation of
this room is currently disabled on this server.",
))?;
if let Some((time, tries)) = db if let Some((time, tries)) = services()
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.read() .read()
@ -120,26 +155,27 @@ impl Service {
} }
let start_time = Instant::now(); let start_time = Instant::now();
db.globals services()
.globals
.roomid_federationhandletime .roomid_federationhandletime
.write() .write()
.unwrap() .unwrap()
.insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time));
if let Err(e) = upgrade_outlier_to_timeline_pdu( if let Err(e) = self
pdu, .upgrade_outlier_to_timeline_pdu(
json, pdu,
&create_event, json,
origin, &create_event,
db, origin,
room_id, room_id,
pub_key_map, pub_key_map,
) )
.await .await
{ {
errors += 1; errors += 1;
warn!("Prev event {} failed: {}", prev_id, e); warn!("Prev event {} failed: {}", prev_id, e);
match db match services()
.globals .globals
.bad_event_ratelimiter .bad_event_ratelimiter
.write() .write()
@ -155,7 +191,8 @@ impl Service {
} }
} }
let elapsed = start_time.elapsed(); let elapsed = start_time.elapsed();
db.globals services()
.globals
.roomid_federationhandletime .roomid_federationhandletime
.write() .write()
.unwrap() .unwrap()
@ -172,22 +209,23 @@ impl Service {
// Done with prev events, now handling the incoming event // Done with prev events, now handling the incoming event
let start_time = Instant::now(); let start_time = Instant::now();
db.globals services()
.globals
.roomid_federationhandletime .roomid_federationhandletime
.write() .write()
.unwrap() .unwrap()
.insert(room_id.to_owned(), (event_id.to_owned(), start_time)); .insert(room_id.to_owned(), (event_id.to_owned(), start_time));
let r = upgrade_outlier_to_timeline_pdu( let r = services().rooms.event_handler.upgrade_outlier_to_timeline_pdu(
incoming_pdu, incoming_pdu,
val, val,
&create_event, &create_event,
origin, origin,
db,
room_id, room_id,
pub_key_map, pub_key_map,
) )
.await; .await;
db.globals services()
.globals
.roomid_federationhandletime .roomid_federationhandletime
.write() .write()
.unwrap() .unwrap()
@ -196,22 +234,23 @@ impl Service {
r r
} }
#[tracing::instrument(skip(create_event, value, db, pub_key_map))] #[tracing::instrument(skip(create_event, value, pub_key_map))]
fn handle_outlier_pdu<'a>( fn handle_outlier_pdu<'a>(
&self,
origin: &'a ServerName, origin: &'a ServerName,
create_event: &'a PduEvent, create_event: &'a PduEvent,
event_id: &'a EventId, event_id: &'a EventId,
room_id: &'a RoomId, room_id: &'a RoomId,
value: BTreeMap<String, CanonicalJsonValue>, value: BTreeMap<String, CanonicalJsonValue>,
db: &'a Database,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>), String>> { ) -> AsyncRecursiveType<'a, Result<(Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>), String>>
{
Box::pin(async move { Box::pin(async move {
// TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json
// We go through all the signatures we see on the value and fetch the corresponding signing // We go through all the signatures we see on the value and fetch the corresponding signing
// keys // keys
fetch_required_signing_keys(&value, pub_key_map, db) self.fetch_required_signing_keys(&value, pub_key_map, db)
.await?; .await?;
// 2. Check signatures, otherwise drop // 2. Check signatures, otherwise drop
@ -223,7 +262,8 @@ impl Service {
})?; })?;
let room_version_id = &create_event_content.room_version; let room_version_id = &create_event_content.room_version;
let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); let room_version =
RoomVersion::new(room_version_id).expect("room version is supported");
let mut val = match ruma::signatures::verify_event( let mut val = match ruma::signatures::verify_event(
&*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?, &*pub_key_map.read().map_err(|_| "RwLock is poisoned.")?,
@ -261,8 +301,7 @@ impl Service {
// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events"
// NOTE: Step 5 is not applied anymore because it failed too often // NOTE: Step 5 is not applied anymore because it failed too often
warn!("Fetching auth events for {}", incoming_pdu.event_id); warn!("Fetching auth events for {}", incoming_pdu.event_id);
fetch_and_handle_outliers( self.fetch_and_handle_outliers(
db,
origin, origin,
&incoming_pdu &incoming_pdu
.auth_events .auth_events
@ -284,7 +323,7 @@ impl Service {
// Build map of auth events // Build map of auth events
let mut auth_events = HashMap::new(); let mut auth_events = HashMap::new();
for id in &incoming_pdu.auth_events { for id in &incoming_pdu.auth_events {
let auth_event = match db.rooms.get_pdu(id)? { let auth_event = match services().rooms.get_pdu(id)? {
Some(e) => e, Some(e) => e,
None => { None => {
warn!("Could not find auth event {}", id); warn!("Could not find auth event {}", id);
@ -303,8 +342,9 @@ impl Service {
v.insert(auth_event); v.insert(auth_event);
} }
hash_map::Entry::Occupied(_) => { hash_map::Entry::Occupied(_) => {
return Err(Error::BadRequest(ErrorKind::InvalidParam, return Err(Error::BadRequest(
"Auth event's type and state_key combination exists multiple times." ErrorKind::InvalidParam,
"Auth event's type and state_key combination exists multiple times.",
)); ));
} }
} }
@ -316,7 +356,10 @@ impl Service {
.map(|a| a.as_ref()) .map(|a| a.as_ref())
!= Some(create_event) != Some(create_event)
{ {
return Err(Error::BadRequest(ErrorKind::InvalidParam("Incoming event refers to wrong create event."))); return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Incoming event refers to wrong create event.",
));
} }
if !state_res::event_auth::auth_check( if !state_res::event_auth::auth_check(
@ -325,15 +368,21 @@ impl Service {
None::<PduEvent>, // TODO: third party invite None::<PduEvent>, // TODO: third party invite
|k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())),
) )
.map_err(|e| {error!(e); Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")})? .map_err(|e| {
{ error!(e);
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")
})? {
return Err(Error::BadRequest(
ErrorKind::InvalidParam,
"Auth check failed",
));
} }
info!("Validation successful."); info!("Validation successful.");
// 7. Persist the event as an outlier. // 7. Persist the event as an outlier.
db.rooms services()
.rooms
.add_pdu_outlier(&incoming_pdu.event_id, &val)?; .add_pdu_outlier(&incoming_pdu.event_id, &val)?;
info!("Added pdu as outlier."); info!("Added pdu as outlier.");
@ -342,22 +391,22 @@ impl Service {
}) })
} }
#[tracing::instrument(skip(incoming_pdu, val, create_event, db, pub_key_map))] #[tracing::instrument(skip(incoming_pdu, val, create_event, pub_key_map))]
async fn upgrade_outlier_to_timeline_pdu( pub async fn upgrade_outlier_to_timeline_pdu(
&self,
incoming_pdu: Arc<PduEvent>, incoming_pdu: Arc<PduEvent>,
val: BTreeMap<String, CanonicalJsonValue>, val: BTreeMap<String, CanonicalJsonValue>,
create_event: &PduEvent, create_event: &PduEvent,
origin: &ServerName, origin: &ServerName,
db: &Database,
room_id: &RoomId, room_id: &RoomId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> Result<Option<Vec<u8>>, String> { ) -> Result<Option<Vec<u8>>, String> {
// Skip the PDU if we already have it as a timeline event // Skip the PDU if we already have it as a timeline event
if let Ok(Some(pduid)) = db.rooms.get_pdu_id(&incoming_pdu.event_id) { if let Ok(Some(pduid)) = services().rooms.get_pdu_id(&incoming_pdu.event_id) {
return Ok(Some(pduid)); return Ok(Some(pduid));
} }
if db if services()
.rooms .rooms
.is_event_soft_failed(&incoming_pdu.event_id) .is_event_soft_failed(&incoming_pdu.event_id)
.map_err(|_| "Failed to ask db for soft fail".to_owned())? .map_err(|_| "Failed to ask db for soft fail".to_owned())?
@ -387,32 +436,32 @@ impl Service {
if incoming_pdu.prev_events.len() == 1 { if incoming_pdu.prev_events.len() == 1 {
let prev_event = &*incoming_pdu.prev_events[0]; let prev_event = &*incoming_pdu.prev_events[0];
let prev_event_sstatehash = db let prev_event_sstatehash = services()
.rooms .rooms
.pdu_shortstatehash(prev_event) .pdu_shortstatehash(prev_event)
.map_err(|_| "Failed talking to db".to_owned())?; .map_err(|_| "Failed talking to db".to_owned())?;
let state = if let Some(shortstatehash) = prev_event_sstatehash { let state = if let Some(shortstatehash) = prev_event_sstatehash {
Some(db.rooms.state_full_ids(shortstatehash).await) Some(services().rooms.state_full_ids(shortstatehash).await)
} else { } else {
None None
}; };
if let Some(Ok(mut state)) = state { if let Some(Ok(mut state)) = state {
info!("Using cached state"); info!("Using cached state");
let prev_pdu = let prev_pdu = services()
db.rooms.get_pdu(prev_event).ok().flatten().ok_or_else(|| { .rooms
.get_pdu(prev_event)
.ok()
.flatten()
.ok_or_else(|| {
"Could not find prev event, but we know the state.".to_owned() "Could not find prev event, but we know the state.".to_owned()
})?; })?;
if let Some(state_key) = &prev_pdu.state_key { if let Some(state_key) = &prev_pdu.state_key {
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)
&prev_pdu.kind.to_string().into(),
state_key,
&db.globals,
)
.map_err(|_| "Failed to create shortstatekey.".to_owned())?; .map_err(|_| "Failed to create shortstatekey.".to_owned())?;
state.insert(shortstatekey, Arc::from(prev_event)); state.insert(shortstatekey, Arc::from(prev_event));
@ -427,19 +476,20 @@ impl Service {
let mut okay = true; let mut okay = true;
for prev_eventid in &incoming_pdu.prev_events { for prev_eventid in &incoming_pdu.prev_events {
let prev_event = if let Ok(Some(pdu)) = db.rooms.get_pdu(prev_eventid) { let prev_event = if let Ok(Some(pdu)) = services().rooms.get_pdu(prev_eventid) {
pdu pdu
} else { } else {
okay = false; okay = false;
break; break;
}; };
let sstatehash = if let Ok(Some(s)) = db.rooms.pdu_shortstatehash(prev_eventid) { let sstatehash =
s if let Ok(Some(s)) = services().rooms.pdu_shortstatehash(prev_eventid) {
} else { s
okay = false; } else {
break; okay = false;
}; break;
};
extremity_sstatehashes.insert(sstatehash, prev_event); extremity_sstatehashes.insert(sstatehash, prev_event);
} }
@ -449,19 +499,18 @@ impl Service {
let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len());
for (sstatehash, prev_event) in extremity_sstatehashes { for (sstatehash, prev_event) in extremity_sstatehashes {
let mut leaf_state: BTreeMap<_, _> = db let mut leaf_state: BTreeMap<_, _> = services()
.rooms .rooms
.state_full_ids(sstatehash) .state_full_ids(sstatehash)
.await .await
.map_err(|_| "Failed to ask db for room state.".to_owned())?; .map_err(|_| "Failed to ask db for room state.".to_owned())?;
if let Some(state_key) = &prev_event.state_key { if let Some(state_key) = &prev_event.state_key {
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(
&prev_event.kind.to_string().into(), &prev_event.kind.to_string().into(),
state_key, state_key,
&db.globals,
) )
.map_err(|_| "Failed to create shortstatekey.".to_owned())?; .map_err(|_| "Failed to create shortstatekey.".to_owned())?;
leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id));
@ -472,7 +521,7 @@ impl Service {
let mut starting_events = Vec::with_capacity(leaf_state.len()); let mut starting_events = Vec::with_capacity(leaf_state.len());
for (k, id) in leaf_state { for (k, id) in leaf_state {
if let Ok((ty, st_key)) = db.rooms.get_statekey_from_short(k) { if let Ok((ty, st_key)) = services().rooms.get_statekey_from_short(k) {
// FIXME: Undo .to_string().into() when StateMap // FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType // is updated to use StateEventType
state.insert((ty.to_string().into(), st_key), id.clone()); state.insert((ty.to_string().into(), st_key), id.clone());
@ -483,7 +532,10 @@ impl Service {
} }
auth_chain_sets.push( auth_chain_sets.push(
get_auth_chain(room_id, starting_events, db) services()
.rooms
.auth_chain
.get_auth_chain(room_id, starting_events, services())
.await .await
.map_err(|_| "Failed to load auth chain.".to_owned())? .map_err(|_| "Failed to load auth chain.".to_owned())?
.collect(), .collect(),
@ -492,15 +544,16 @@ impl Service {
fork_states.push(state); fork_states.push(state);
} }
let lock = db.globals.stateres_mutex.lock(); let lock = services().globals.stateres_mutex.lock();
let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { let result =
let res = db.rooms.get_pdu(id); state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| {
if let Err(e) = &res { let res = services().rooms.get_pdu(id);
error!("LOOK AT ME Failed to fetch event: {}", e); if let Err(e) = &res {
} error!("LOOK AT ME Failed to fetch event: {}", e);
res.ok().flatten() }
}); res.ok().flatten()
});
drop(lock); drop(lock);
state_at_incoming_event = match result { state_at_incoming_event = match result {
@ -508,14 +561,15 @@ impl Service {
new_state new_state
.into_iter() .into_iter()
.map(|((event_type, state_key), event_id)| { .map(|((event_type, state_key), event_id)| {
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(
&event_type.to_string().into(), &event_type.to_string().into(),
&state_key, &state_key,
&db.globals,
) )
.map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; .map_err(|_| {
"Failed to get_or_create_shortstatekey".to_owned()
})?;
Ok((shortstatekey, event_id)) Ok((shortstatekey, event_id))
}) })
.collect::<Result<_, String>>()?, .collect::<Result<_, String>>()?,
@ -532,10 +586,9 @@ impl Service {
info!("Calling /state_ids"); info!("Calling /state_ids");
// Call /state_ids to find out what the state at this pdu is. We trust the server's // Call /state_ids to find out what the state at this pdu is. We trust the server's
// response to some extend, but we still do a lot of checks on the events // response to some extend, but we still do a lot of checks on the events
match db match services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
origin, origin,
get_room_state_ids::v1::Request { get_room_state_ids::v1::Request {
room_id, room_id,
@ -546,18 +599,18 @@ impl Service {
{ {
Ok(res) => { Ok(res) => {
info!("Fetching state events at event."); info!("Fetching state events at event.");
let state_vec = fetch_and_handle_outliers( let state_vec = self
db, .fetch_and_handle_outliers(
origin, origin,
&res.pdu_ids &res.pdu_ids
.iter() .iter()
.map(|x| Arc::from(&**x)) .map(|x| Arc::from(&**x))
.collect::<Vec<_>>(), .collect::<Vec<_>>(),
create_event, create_event,
room_id, room_id,
pub_key_map, pub_key_map,
) )
.await; .await;
let mut state: BTreeMap<_, Arc<EventId>> = BTreeMap::new(); let mut state: BTreeMap<_, Arc<EventId>> = BTreeMap::new();
for (pdu, _) in state_vec { for (pdu, _) in state_vec {
@ -566,13 +619,9 @@ impl Service {
.clone() .clone()
.ok_or_else(|| "Found non-state pdu in state events.".to_owned())?; .ok_or_else(|| "Found non-state pdu in state events.".to_owned())?;
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)
&pdu.kind.to_string().into(),
&state_key,
&db.globals,
)
.map_err(|_| "Failed to create shortstatekey.".to_owned())?; .map_err(|_| "Failed to create shortstatekey.".to_owned())?;
match state.entry(shortstatekey) { match state.entry(shortstatekey) {
@ -587,7 +636,7 @@ impl Service {
} }
// The original create event must still be in the state // The original create event must still be in the state
let create_shortstatekey = db let create_shortstatekey = services()
.rooms .rooms
.get_shortstatekey(&StateEventType::RoomCreate, "") .get_shortstatekey(&StateEventType::RoomCreate, "")
.map_err(|_| "Failed to talk to db.")? .map_err(|_| "Failed to talk to db.")?
@ -618,12 +667,13 @@ impl Service {
&incoming_pdu, &incoming_pdu,
None::<PduEvent>, // TODO: third party invite None::<PduEvent>, // TODO: third party invite
|k, s| { |k, s| {
db.rooms services()
.rooms
.get_shortstatekey(&k.to_string().into(), s) .get_shortstatekey(&k.to_string().into(), s)
.ok() .ok()
.flatten() .flatten()
.and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey))
.and_then(|event_id| db.rooms.get_pdu(event_id).ok().flatten()) .and_then(|event_id| services().rooms.get_pdu(event_id).ok().flatten())
}, },
) )
.map_err(|_e| "Auth check failed.".to_owned())?; .map_err(|_e| "Auth check failed.".to_owned())?;
@ -636,7 +686,8 @@ impl Service {
// We start looking at current room state now, so lets lock the room // We start looking at current room state now, so lets lock the room
let mutex_state = Arc::clone( let mutex_state = Arc::clone(
db.globals services()
.globals
.roomid_mutex_state .roomid_mutex_state
.write() .write()
.unwrap() .unwrap()
@ -648,7 +699,7 @@ impl Service {
// Now we calculate the set of extremities this room has after the incoming event has been // Now we calculate the set of extremities this room has after the incoming event has been
// applied. We start with the previous extremities (aka leaves) // applied. We start with the previous extremities (aka leaves)
info!("Calculating extremities"); info!("Calculating extremities");
let mut extremities = db let mut extremities = services()
.rooms .rooms
.get_pdu_leaves(room_id) .get_pdu_leaves(room_id)
.map_err(|_| "Failed to load room leaves".to_owned())?; .map_err(|_| "Failed to load room leaves".to_owned())?;
@ -661,14 +712,16 @@ impl Service {
} }
// Only keep those extremities were not referenced yet // Only keep those extremities were not referenced yet
extremities.retain(|id| !matches!(db.rooms.is_event_referenced(room_id, id), Ok(true))); extremities
.retain(|id| !matches!(services().rooms.is_event_referenced(room_id, id), Ok(true)));
info!("Compressing state at event"); info!("Compressing state at event");
let state_ids_compressed = state_at_incoming_event let state_ids_compressed = state_at_incoming_event
.iter() .iter()
.map(|(shortstatekey, id)| { .map(|(shortstatekey, id)| {
db.rooms services()
.compress_state_event(*shortstatekey, id, &db.globals) .rooms
.compress_state_event(*shortstatekey, id)
.map_err(|_| "Failed to compress_state_event".to_owned()) .map_err(|_| "Failed to compress_state_event".to_owned())
}) })
.collect::<Result<_, _>>()?; .collect::<Result<_, _>>()?;
@ -676,7 +729,7 @@ impl Service {
// 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it // 13. Check if the event passes auth based on the "current state" of the room, if not "soft fail" it
info!("Starting soft fail auth check"); info!("Starting soft fail auth check");
let auth_events = db let auth_events = services()
.rooms .rooms
.get_auth_events( .get_auth_events(
room_id, room_id,
@ -696,11 +749,10 @@ impl Service {
.map_err(|_e| "Auth check failed.".to_owned())?; .map_err(|_e| "Auth check failed.".to_owned())?;
if soft_fail { if soft_fail {
append_incoming_pdu( self.append_incoming_pdu(
db,
&incoming_pdu, &incoming_pdu,
val, val,
extremities.iter().map(Deref::deref), extremities.iter().map(std::ops::Deref::deref),
state_ids_compressed, state_ids_compressed,
soft_fail, soft_fail,
&state_lock, &state_lock,
@ -712,7 +764,8 @@ impl Service {
// Soft fail, we keep the event as an outlier but don't add it to the timeline // Soft fail, we keep the event as an outlier but don't add it to the timeline
warn!("Event was soft failed: {:?}", incoming_pdu); warn!("Event was soft failed: {:?}", incoming_pdu);
db.rooms services()
.rooms
.mark_event_soft_failed(&incoming_pdu.event_id) .mark_event_soft_failed(&incoming_pdu.event_id)
.map_err(|_| "Failed to set soft failed flag".to_owned())?; .map_err(|_| "Failed to set soft failed flag".to_owned())?;
return Err("Event has been soft failed".into()); return Err("Event has been soft failed".into());
@ -720,13 +773,13 @@ impl Service {
if incoming_pdu.state_key.is_some() { if incoming_pdu.state_key.is_some() {
info!("Loading current room state ids"); info!("Loading current room state ids");
let current_sstatehash = db let current_sstatehash = services()
.rooms .rooms
.current_shortstatehash(room_id) .current_shortstatehash(room_id)
.map_err(|_| "Failed to load current state hash.".to_owned())? .map_err(|_| "Failed to load current state hash.".to_owned())?
.expect("every room has state"); .expect("every room has state");
let current_state_ids = db let current_state_ids = services()
.rooms .rooms
.state_full_ids(current_sstatehash) .state_full_ids(current_sstatehash)
.await .await
@ -737,14 +790,14 @@ impl Service {
info!("Loading extremities"); info!("Loading extremities");
for id in dbg!(&extremities) { for id in dbg!(&extremities) {
match db match services()
.rooms .rooms
.get_pdu(id) .get_pdu(id)
.map_err(|_| "Failed to ask db for pdu.".to_owned())? .map_err(|_| "Failed to ask db for pdu.".to_owned())?
{ {
Some(leaf_pdu) => { Some(leaf_pdu) => {
extremity_sstatehashes.insert( extremity_sstatehashes.insert(
db.rooms services()
.pdu_shortstatehash(&leaf_pdu.event_id) .pdu_shortstatehash(&leaf_pdu.event_id)
.map_err(|_| "Failed to ask db for pdu state hash.".to_owned())? .map_err(|_| "Failed to ask db for pdu state hash.".to_owned())?
.ok_or_else(|| { .ok_or_else(|| {
@ -777,13 +830,9 @@ impl Service {
// We also add state after incoming event to the fork states // We also add state after incoming event to the fork states
let mut state_after = state_at_incoming_event.clone(); let mut state_after = state_at_incoming_event.clone();
if let Some(state_key) = &incoming_pdu.state_key { if let Some(state_key) = &incoming_pdu.state_key {
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)
&incoming_pdu.kind.to_string().into(),
state_key,
&db.globals,
)
.map_err(|_| "Failed to create shortstatekey.".to_owned())?; .map_err(|_| "Failed to create shortstatekey.".to_owned())?;
state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id));
@ -801,8 +850,9 @@ impl Service {
fork_states[0] fork_states[0]
.iter() .iter()
.map(|(k, id)| { .map(|(k, id)| {
db.rooms services()
.compress_state_event(*k, id, &db.globals) .rooms
.compress_state_event(*k, id)
.map_err(|_| "Failed to compress_state_event.".to_owned()) .map_err(|_| "Failed to compress_state_event.".to_owned())
}) })
.collect::<Result<_, _>>()? .collect::<Result<_, _>>()?
@ -814,14 +864,16 @@ impl Service {
let mut auth_chain_sets = Vec::new(); let mut auth_chain_sets = Vec::new();
for state in &fork_states { for state in &fork_states {
auth_chain_sets.push( auth_chain_sets.push(
get_auth_chain( services()
room_id, .rooms
state.iter().map(|(_, id)| id.clone()).collect(), .auth_chain
db, .get_auth_chain(
) room_id,
.await state.iter().map(|(_, id)| id.clone()).collect(),
.map_err(|_| "Failed to load auth chain.".to_owned())? )
.collect(), .await
.map_err(|_| "Failed to load auth chain.".to_owned())?
.collect(),
); );
} }
@ -832,7 +884,8 @@ impl Service {
.map(|map| { .map(|map| {
map.into_iter() map.into_iter()
.filter_map(|(k, id)| { .filter_map(|(k, id)| {
db.rooms services()
.rooms
.get_statekey_from_short(k) .get_statekey_from_short(k)
// FIXME: Undo .to_string().into() when StateMap // FIXME: Undo .to_string().into() when StateMap
// is updated to use StateEventType // is updated to use StateEventType
@ -846,13 +899,13 @@ impl Service {
info!("Resolving state"); info!("Resolving state");
let lock = db.globals.stateres_mutex.lock(); let lock = services().globals.stateres_mutex.lock();
let state = match state_res::resolve( let state = match state_res::resolve(
room_version_id, room_version_id,
&fork_states, &fork_states,
auth_chain_sets, auth_chain_sets,
|id| { |id| {
let res = db.rooms.get_pdu(id); let res = services().rooms.get_pdu(id);
if let Err(e) = &res { if let Err(e) = &res {
error!("LOOK AT ME Failed to fetch event: {}", e); error!("LOOK AT ME Failed to fetch event: {}", e);
} }
@ -872,16 +925,13 @@ impl Service {
state state
.into_iter() .into_iter()
.map(|((event_type, state_key), event_id)| { .map(|((event_type, state_key), event_id)| {
let shortstatekey = db let shortstatekey = services()
.rooms .rooms
.get_or_create_shortstatekey( .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)
&event_type.to_string().into(),
&state_key,
&db.globals,
)
.map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?; .map_err(|_| "Failed to get_or_create_shortstatekey".to_owned())?;
db.rooms services()
.compress_state_event(shortstatekey, &event_id, &db.globals) .rooms
.compress_state_event(shortstatekey, &event_id)
.map_err(|_| "Failed to compress state event".to_owned()) .map_err(|_| "Failed to compress state event".to_owned())
}) })
.collect::<Result<_, _>>()? .collect::<Result<_, _>>()?
@ -890,8 +940,9 @@ impl Service {
// Set the new room state to the resolved state // Set the new room state to the resolved state
if update_state { if update_state {
info!("Forcing new room state"); info!("Forcing new room state");
db.rooms services()
.force_state(room_id, new_room_state, db) .rooms
.force_state(room_id, new_room_state)
.map_err(|_| "Failed to set new room state.".to_owned())?; .map_err(|_| "Failed to set new room state.".to_owned())?;
} }
} }
@ -903,19 +954,19 @@ impl Service {
// We use the `state_at_event` instead of `state_after` so we accurately // We use the `state_at_event` instead of `state_after` so we accurately
// represent the state for this event. // represent the state for this event.
let pdu_id = append_incoming_pdu( let pdu_id = self
db, .append_incoming_pdu(
&incoming_pdu, &incoming_pdu,
val, val,
extremities.iter().map(Deref::deref), extremities.iter().map(std::ops::Deref::deref),
state_ids_compressed, state_ids_compressed,
soft_fail, soft_fail,
&state_lock, &state_lock,
) )
.map_err(|e| { .map_err(|e| {
warn!("Failed to add pdu to db: {}", e); warn!("Failed to add pdu to db: {}", e);
"Failed to add pdu to db.".to_owned() "Failed to add pdu to db.".to_owned()
})?; })?;
info!("Appended incoming pdu"); info!("Appended incoming pdu");
@ -935,15 +986,22 @@ impl Service {
/// d. TODO: Ask other servers over federation? /// d. TODO: Ask other servers over federation?
#[tracing::instrument(skip_all)] #[tracing::instrument(skip_all)]
pub(crate) fn fetch_and_handle_outliers<'a>( pub(crate) fn fetch_and_handle_outliers<'a>(
db: &'a Database, &self,
origin: &'a ServerName, origin: &'a ServerName,
events: &'a [Arc<EventId>], events: &'a [Arc<EventId>],
create_event: &'a PduEvent, create_event: &'a PduEvent,
room_id: &'a RoomId, room_id: &'a RoomId,
pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>, pub_key_map: &'a RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>> { ) -> AsyncRecursiveType<'a, Vec<(Arc<PduEvent>, Option<BTreeMap<String, CanonicalJsonValue>>)>>
{
Box::pin(async move { Box::pin(async move {
let back_off = |id| match db.globals.bad_event_ratelimiter.write().unwrap().entry(id) { let back_off = |id| match services()
.globals
.bad_event_ratelimiter
.write()
.unwrap()
.entry(id)
{
hash_map::Entry::Vacant(e) => { hash_map::Entry::Vacant(e) => {
e.insert((Instant::now(), 1)); e.insert((Instant::now(), 1));
} }
@ -952,10 +1010,16 @@ impl Service {
let mut pdus = vec![]; let mut pdus = vec![];
for id in events { for id in events {
if let Some((time, tries)) = db.globals.bad_event_ratelimiter.read().unwrap().get(&**id) if let Some((time, tries)) = services()
.globals
.bad_event_ratelimiter
.read()
.unwrap()
.get(&**id)
{ {
// Exponential backoff // Exponential backoff
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); let mut min_elapsed_duration =
Duration::from_secs(5 * 60) * (*tries) * (*tries);
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
min_elapsed_duration = Duration::from_secs(60 * 60 * 24); min_elapsed_duration = Duration::from_secs(60 * 60 * 24);
} }
@ -969,7 +1033,7 @@ impl Service {
// a. Look in the main timeline (pduid_pdu tree) // a. Look in the main timeline (pduid_pdu tree)
// b. Look at outlier pdu tree // b. Look at outlier pdu tree
// (get_pdu_json checks both) // (get_pdu_json checks both)
if let Ok(Some(local_pdu)) = db.rooms.get_pdu(id) { if let Ok(Some(local_pdu)) = services().rooms.get_pdu(id) {
trace!("Found {} in db", id); trace!("Found {} in db", id);
pdus.push((local_pdu, None)); pdus.push((local_pdu, None));
continue; continue;
@ -992,16 +1056,15 @@ impl Service {
tokio::task::yield_now().await; tokio::task::yield_now().await;
} }
if let Ok(Some(_)) = db.rooms.get_pdu(&next_id) { if let Ok(Some(_)) = services().rooms.get_pdu(&next_id) {
trace!("Found {} in db", id); trace!("Found {} in db", id);
continue; continue;
} }
info!("Fetching {} over federation.", next_id); info!("Fetching {} over federation.", next_id);
match db match services()
.sending .sending
.send_federation_request( .send_federation_request(
&db.globals,
origin, origin,
get_event::v1::Request { event_id: &next_id }, get_event::v1::Request { event_id: &next_id },
) )
@ -1010,7 +1073,7 @@ impl Service {
Ok(res) => { Ok(res) => {
info!("Got {} over federation", next_id); info!("Got {} over federation", next_id);
let (calculated_event_id, value) = let (calculated_event_id, value) =
match crate::pdu::gen_event_id_canonical_json(&res.pdu, &db) { match pdu::gen_event_id_canonical_json(&res.pdu) {
Ok(t) => t, Ok(t) => t,
Err(_) => { Err(_) => {
back_off((*next_id).to_owned()); back_off((*next_id).to_owned());
@ -1051,16 +1114,16 @@ impl Service {
} }
for (next_id, value) in events_in_reverse_order.iter().rev() { for (next_id, value) in events_in_reverse_order.iter().rev() {
match handle_outlier_pdu( match self
origin, .handle_outlier_pdu(
create_event, origin,
next_id, create_event,
room_id, next_id,
value.clone(), room_id,
db, value.clone(),
pub_key_map, pub_key_map,
) )
.await .await
{ {
Ok((pdu, json)) => { Ok((pdu, json)) => {
if next_id == id { if next_id == id {
@ -1078,9 +1141,14 @@ impl Service {
}) })
} }
async fn fetch_unknown_prev_events(
&self,
fn fetch_unknown_prev_events(initial_set: Vec<Arc<EventId>>) -> Vec<Arc<EventId>> { origin: &ServerName,
create_event: &PduEvent,
room_id: &RoomId,
pub_key_map: &RwLock<BTreeMap<String, BTreeMap<String, Base64>>>,
initial_set: Vec<Arc<EventId>>,
) -> Vec<(Arc<EventId>, HashMap<Arc<EventId>, (Arc<PduEvent>, BTreeMap<String, CanonicalJsonValue>)>)> {
let mut graph: HashMap<Arc<EventId>, _> = HashMap::new(); let mut graph: HashMap<Arc<EventId>, _> = HashMap::new();
let mut eventid_info = HashMap::new(); let mut eventid_info = HashMap::new();
let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set; let mut todo_outlier_stack: Vec<Arc<EventId>> = initial_set;
@ -1088,16 +1156,16 @@ impl Service {
let mut amount = 0; let mut amount = 0;
while let Some(prev_event_id) = todo_outlier_stack.pop() { while let Some(prev_event_id) = todo_outlier_stack.pop() {
if let Some((pdu, json_opt)) = fetch_and_handle_outliers( if let Some((pdu, json_opt)) = self
db, .fetch_and_handle_outliers(
origin, origin,
&[prev_event_id.clone()], &[prev_event_id.clone()],
&create_event, &create_event,
room_id, room_id,
pub_key_map, pub_key_map,
) )
.await .await
.pop() .pop()
{ {
if amount > 100 { if amount > 100 {
// Max limit reached // Max limit reached
@ -1106,9 +1174,13 @@ impl Service {
continue; continue;
} }
if let Some(json) = if let Some(json) = json_opt.or_else(|| {
json_opt.or_else(|| db.rooms.get_outlier_pdu_json(&prev_event_id).ok().flatten()) services()
{ .rooms
.get_outlier_pdu_json(&prev_event_id)
.ok()
.flatten()
}) {
if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts {
amount += 1; amount += 1;
for prev_prev in &pdu.prev_events { for prev_prev in &pdu.prev_events {
@ -1153,6 +1225,6 @@ impl Service {
}) })
.map_err(|_| "Error sorting prev events".to_owned())?; .map_err(|_| "Error sorting prev events".to_owned())?;
sorted (sorted, eventid_info)
} }
} }

View file

@ -1,3 +1,5 @@
use ruma::{RoomId, DeviceId, UserId};
pub trait Data { pub trait Data {
fn lazy_load_was_sent_before( fn lazy_load_was_sent_before(
&self, &self,

View file

@ -1,5 +1,8 @@
mod data; mod data;
use std::collections::HashSet;
pub use data::Data; pub use data::Data;
use ruma::{DeviceId, UserId, RoomId};
use crate::service::*; use crate::service::*;
@ -47,7 +50,7 @@ impl Service<_> {
room_id: &RoomId, room_id: &RoomId,
since: u64, since: u64,
) -> Result<()> { ) -> Result<()> {
self.db.lazy_load_confirm_delivery(user_d, device_id, room_id, since) self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, since)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
@ -57,6 +60,6 @@ impl Service<_> {
device_id: &DeviceId, device_id: &DeviceId,
room_id: &RoomId, room_id: &RoomId,
) -> Result<()> { ) -> Result<()> {
self.db.lazy_load_reset(user_id, device_id, room_id); self.db.lazy_load_reset(user_id, device_id, room_id)
} }
} }

View file

@ -1,3 +1,5 @@
use ruma::RoomId;
pub trait Data { pub trait Data {
fn exists(&self, room_id: &RoomId) -> Result<bool>; fn exists(&self, room_id: &RoomId) -> Result<bool>;
} }

View file

@ -1,5 +1,6 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::RoomId;
use crate::service::*; use crate::service::*;

View file

@ -1,216 +1,37 @@
mod edus; pub mod alias;
pub mod auth_chain;
pub use edus::RoomEdus; pub mod directory;
pub mod edus;
use crate::{ pub mod event_handler;
pdu::{EventHash, PduBuilder}, pub mod lazy_loading;
utils, Database, Error, PduEvent, Result, pub mod metadata;
}; pub mod outlier;
use lru_cache::LruCache; pub mod pdu_metadata;
use regex::Regex; pub mod search;
use ring::digest; pub mod short;
use ruma::{ pub mod state;
api::{client::error::ErrorKind, federation}, pub mod state_accessor;
events::{ pub mod state_cache;
direct::DirectEvent, pub mod state_compressor;
ignored_user_list::IgnoredUserListEvent, pub mod timeline;
push_rules::PushRulesEvent, pub mod user;
room::{
create::RoomCreateEventContent,
member::{MembershipState, RoomMemberEventContent},
power_levels::RoomPowerLevelsEventContent,
},
tag::TagEvent,
AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType,
RoomAccountDataEventType, RoomEventType, StateEventType,
},
push::{Action, Ruleset, Tweak},
serde::{CanonicalJsonObject, CanonicalJsonValue, Raw},
state_res::{self, RoomVersion, StateMap},
uint, DeviceId, EventId, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId,
};
use serde::Deserialize;
use serde_json::value::to_raw_value;
use std::{
borrow::Cow,
collections::{hash_map, BTreeMap, HashMap, HashSet},
fmt::Debug,
iter,
mem::size_of,
sync::{Arc, Mutex, RwLock},
};
use tokio::sync::MutexGuard;
use tracing::{error, warn};
use super::{abstraction::Tree, pusher};
/// The unique identifier of each state group.
///
/// This is created when a state group is added to the database by
/// hashing the entire state.
pub type StateHashId = Vec<u8>;
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Rooms {
pub edus: RoomEdus,
pub(super) pduid_pdu: Arc<dyn Tree>, // PduId = ShortRoomId + Count
pub(super) eventid_pduid: Arc<dyn Tree>,
pub(super) roomid_pduleaves: Arc<dyn Tree>,
pub(super) alias_roomid: Arc<dyn Tree>,
pub(super) aliasid_alias: Arc<dyn Tree>, // AliasId = RoomId + Count
pub(super) publicroomids: Arc<dyn Tree>,
pub(super) tokenids: Arc<dyn Tree>, // TokenId = ShortRoomId + Token + PduIdCount
/// Participating servers in a room.
pub(super) roomserverids: Arc<dyn Tree>, // RoomServerId = RoomId + ServerName
pub(super) serverroomids: Arc<dyn Tree>, // ServerRoomId = ServerName + RoomId
pub(super) userroomid_joined: Arc<dyn Tree>,
pub(super) roomuserid_joined: Arc<dyn Tree>,
pub(super) roomid_joinedcount: Arc<dyn Tree>,
pub(super) roomid_invitedcount: Arc<dyn Tree>,
pub(super) roomuseroncejoinedids: Arc<dyn Tree>,
pub(super) userroomid_invitestate: Arc<dyn Tree>, // InviteState = Vec<Raw<Pdu>>
pub(super) roomuserid_invitecount: Arc<dyn Tree>, // InviteCount = Count
pub(super) userroomid_leftstate: Arc<dyn Tree>,
pub(super) roomuserid_leftcount: Arc<dyn Tree>,
pub(super) disabledroomids: Arc<dyn Tree>, // Rooms where incoming federation handling is disabled
pub(super) lazyloadedids: Arc<dyn Tree>, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId
pub(super) userroomid_notificationcount: Arc<dyn Tree>, // NotifyCount = u64
pub(super) userroomid_highlightcount: Arc<dyn Tree>, // HightlightCount = u64
/// Remember the current state hash of a room.
pub(super) roomid_shortstatehash: Arc<dyn Tree>,
pub(super) roomsynctoken_shortstatehash: Arc<dyn Tree>,
/// Remember the state hash at events in the past.
pub(super) shorteventid_shortstatehash: Arc<dyn Tree>,
/// StateKey = EventType + StateKey, ShortStateKey = Count
pub(super) statekey_shortstatekey: Arc<dyn Tree>,
pub(super) shortstatekey_statekey: Arc<dyn Tree>,
pub(super) roomid_shortroomid: Arc<dyn Tree>,
pub(super) shorteventid_eventid: Arc<dyn Tree>,
pub(super) eventid_shorteventid: Arc<dyn Tree>,
pub(super) statehash_shortstatehash: Arc<dyn Tree>,
pub(super) shortstatehash_statediff: Arc<dyn Tree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
pub(super) shorteventid_authchain: Arc<dyn Tree>,
/// RoomId + EventId -> outlier PDU.
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
pub(super) eventid_outlierpdu: Arc<dyn Tree>,
pub(super) softfailedeventids: Arc<dyn Tree>,
/// RoomId + EventId -> Parent PDU EventId.
pub(super) referencedevents: Arc<dyn Tree>,
pub(super) pdu_cache: Mutex<LruCache<Box<EventId>, Arc<PduEvent>>>,
pub(super) shorteventid_cache: Mutex<LruCache<u64, Arc<EventId>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<HashSet<u64>>>>,
pub(super) eventidshort_cache: Mutex<LruCache<Box<EventId>, u64>>,
pub(super) statekeyshort_cache: Mutex<LruCache<(StateEventType, String), u64>>,
pub(super) shortstatekey_cache: Mutex<LruCache<u64, (StateEventType, String)>>,
pub(super) our_real_users_cache: RwLock<HashMap<Box<RoomId>, Arc<HashSet<Box<UserId>>>>>,
pub(super) appservice_in_room_cache: RwLock<HashMap<Box<RoomId>, HashMap<String, bool>>>,
pub(super) lazy_load_waiting:
Mutex<HashMap<(Box<UserId>, Box<DeviceId>, Box<RoomId>, u64), HashSet<Box<UserId>>>>,
pub(super) stateinfo_cache: Mutex<
LruCache<
u64,
Vec<(
u64, // sstatehash
HashSet<CompressedStateEvent>, // full state
HashSet<CompressedStateEvent>, // added
HashSet<CompressedStateEvent>, // removed
)>,
>,
>,
pub(super) lasttimelinecount_cache: Mutex<HashMap<Box<RoomId>, u64>>,
}
impl Rooms {
/// Returns true if a given room version is supported
#[tracing::instrument(skip(self, db))]
pub fn is_supported_version(&self, db: &Database, room_version: &RoomVersionId) -> bool {
db.globals.supported_room_versions().contains(room_version)
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self))]
pub fn get_auth_events(
&self,
room_id: &RoomId,
kind: &RoomEventType,
sender: &UserId,
state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let shortstatehash =
if let Some(current_shortstatehash) = self.current_shortstatehash(room_id)? {
current_shortstatehash
} else {
return Ok(HashMap::new());
};
let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)
.expect("content is a valid JSON object");
let mut sauthevents = auth_events
.into_iter()
.filter_map(|(event_type, state_key)| {
self.get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
})
.collect::<HashMap<_, _>>();
let full_state = self
.load_shortstatehash_info(shortstatehash)?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.into_iter()
.filter_map(|compressed| self.parse_compressed_state_event(compressed).ok())
.filter_map(|(shortstatekey, event_id)| {
sauthevents.remove(&shortstatekey).map(|k| (k, event_id))
})
.filter_map(|(k, event_id)| self.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu)))
.collect())
}
/// Generate a new StateHash.
///
/// A unique hash made from hashing all PDU ids of the state joined with 0xff.
fn calculate_hash(&self, bytes_list: &[&[u8]]) -> StateHashId {
// We only hash the pdu's event ids, not the whole pdu
let bytes = bytes_list.join(&0xff);
let hash = digest::digest(&digest::SHA256, &bytes);
hash.as_ref().into()
}
#[tracing::instrument(skip(self))]
pub fn iter_ids(&self) -> impl Iterator<Item = Result<Box<RoomId>>> + '_ {
self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Room ID in publicroomids is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
})
}
pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
pub struct Service<D> {
pub alias: alias::Service<D>,
pub auth_chain: auth_chain::Service<D>,
pub directory: directory::Service<D>,
pub edus: edus::Service<D>,
pub event_handler: event_handler::Service,
pub lazy_loading: lazy_loading::Service<D>,
pub metadata: metadata::Service<D>,
pub outlier: outlier::Service<D>,
pub pdu_metadata: pdu_metadata::Service<D>,
pub search: search::Service<D>,
pub short: short::Service<D>,
pub state: state::Service<D>,
pub state_accessor: state_accessor::Service<D>,
pub state_cache: state_cache::Service<D>,
pub state_compressor: state_compressor::Service<D>,
pub timeline: timeline::Service<D>,
pub user: user::Service<D>,
} }

View file

@ -1,3 +1,7 @@
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::PduEvent;
pub trait Data { pub trait Data {
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>; fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>>;
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>; fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>>;

View file

@ -1,7 +1,8 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::{EventId, signatures::CanonicalJsonObject};
use crate::service::*; use crate::{service::*, PduEvent};
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,

View file

@ -1,3 +1,7 @@
use std::sync::Arc;
use ruma::{EventId, RoomId};
pub trait Data { pub trait Data {
fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>; fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()>;
fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>; fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool>;

View file

@ -1,5 +1,8 @@
mod data; mod data;
use std::sync::Arc;
pub use data::Data; pub use data::Data;
use ruma::{RoomId, EventId};
use crate::service::*; use crate::service::*;

View file

@ -1,7 +1,9 @@
pub trait Data { use ruma::RoomId;
pub fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()>;
pub fn search_pdus<'a>( pub trait Data {
fn index_pdu<'a>(&self, room_id: &RoomId, pdu_id: u64, message_body: String) -> Result<()>;
fn search_pdus<'a>(
&'a self, &'a self,
room_id: &RoomId, room_id: &RoomId,
search_string: &str, search_string: &str,

View file

@ -1,7 +1,6 @@
mod data; mod data;
pub use data::Data; pub use data::Data;
use ruma::RoomId;
use crate::service::*;
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,

View file

@ -1,7 +1,10 @@
mod data; mod data;
pub use data::Data; use std::sync::Arc;
use crate::service::*; pub use data::Data;
use ruma::{EventId, events::StateEventType};
use crate::{service::*, Error, utils};
pub struct Service<D: Data> { pub struct Service<D: Data> {
db: D, db: D,
@ -188,7 +191,6 @@ impl Service<_> {
fn get_or_create_shortstatehash( fn get_or_create_shortstatehash(
&self, &self,
state_hash: &StateHashId, state_hash: &StateHashId,
globals: &super::globals::Globals,
) -> Result<(u64, bool)> { ) -> Result<(u64, bool)> {
Ok(match self.statehash_shortstatehash.get(state_hash)? { Ok(match self.statehash_shortstatehash.get(state_hash)? {
Some(shortstatehash) => ( Some(shortstatehash) => (

Some files were not shown because too many files have changed in this diff Show more