diff --git a/rustfmt.toml b/rustfmt.toml new file mode 100644 index 00000000..3216a9aa --- /dev/null +++ b/rustfmt.toml @@ -0,0 +1,27 @@ +edition = "2021" + +condense_wildcard_suffixes = true +format_code_in_doc_comments = true +format_macro_bodies = true +format_macro_matchers = true +format_strings = true +hex_literal_case = "Upper" +max_width = 120 +tab_spaces = 4 +array_width = 80 +comment_width = 80 +wrap_comments = true +fn_params_layout = "Compressed" +fn_call_width = 80 +fn_single_line = true +hard_tabs = true +match_block_trailing_comma = true +imports_granularity = "Crate" +normalize_comments = false +reorder_impl_items = true +reorder_imports = true +group_imports = "StdExternalCrate" +newline_style = "Unix" +use_field_init_shorthand = true +use_small_heuristics = "Off" +use_try_shorthand = true \ No newline at end of file diff --git a/src/api/appservice_server.rs b/src/api/appservice_server.rs index cc50be19..73c4d12e 100644 --- a/src/api/appservice_server.rs +++ b/src/api/appservice_server.rs @@ -1,114 +1,94 @@ -use crate::{services, utils, Error, Result}; -use bytes::BytesMut; -use ruma::api::{ - appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, -}; use std::{fmt::Debug, mem, time::Duration}; + +use bytes::BytesMut; +use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken}; use tracing::warn; +use crate::{services, utils, Error, Result}; + /// Sends a request to an appservice /// -/// Only returns None if there is no url specified in the appservice registration file -pub(crate) async fn send_request( - registration: Registration, - request: T, -) -> Option> +/// Only returns None if there is no url specified in the appservice +/// registration file +pub(crate) async fn send_request(registration: Registration, request: T) -> Option> where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug, { - if let Some(destination) = registration.url { - let 
hs_token = registration.hs_token.as_str(); + if let Some(destination) = registration.url { + let hs_token = registration.hs_token.as_str(); - let mut http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(hs_token), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - }) - .unwrap() - .map(bytes::BytesMut::freeze); + let mut http_request = request + .try_into_http_request::( + &destination, + SendAccessToken::IfRequired(hs_token), + &[MatrixVersion::V1_0], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + }) + .unwrap() + .map(bytes::BytesMut::freeze); - let mut parts = http_request.uri().clone().into_parts(); - let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); - let symbol = if old_path_and_query.contains('?') { - "&" - } else { - "?" - }; + let mut parts = http_request.uri().clone().into_parts(); + let old_path_and_query = parts.path_and_query.unwrap().as_str().to_owned(); + let symbol = if old_path_and_query.contains('?') { + "&" + } else { + "?" 
+ }; - parts.path_and_query = Some( - (old_path_and_query + symbol + "access_token=" + hs_token) - .parse() - .unwrap(), - ); - *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); + parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap()); + *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid"); - let mut reqwest_request = reqwest::Request::try_from(http_request) - .expect("all http requests are valid reqwest requests"); + let mut reqwest_request = + reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests"); - *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); + *reqwest_request.timeout_mut() = Some(Duration::from_secs(120)); - let url = reqwest_request.url().clone(); - let mut response = match services() - .globals - .default_client() - .execute(reqwest_request) - .await - { - Ok(r) => r, - Err(e) => { - warn!( - "Could not send request to appservice {} at {}: {}", - registration.id, destination, e - ); - return Some(Err(e.into())); - } - }; + let url = reqwest_request.url().clone(); + let mut response = match services().globals.default_client().execute(reqwest_request).await { + Ok(r) => r, + Err(e) => { + warn!( + "Could not send request to appservice {} at {}: {}", + registration.id, destination, e + ); + return Some(Err(e.into())); + }, + }; - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + mem::swap( + response.headers_mut(), + 
http_response_builder.headers_mut().expect("http::response::Builder is usable"), + ); - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error: {}", e); - Vec::new().into() - }); // TODO: handle timeout + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error: {}", e); + Vec::new().into() + }); // TODO: handle timeout - if !status.is_success() { - warn!( - "Appservice returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - utils::string_from_bytes(&body) - ); - } + if !status.is_success() { + warn!( + "Appservice returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + utils::string_from_bytes(&body) + ); + } - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - Some(response.map_err(|_| { - warn!( - "Appservice returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Server returned bad response.") - })) - } else { - None - } + let response = T::IncomingResponse::try_from_http_response( + http_response_builder.body(body).expect("reqwest body is valid http body"), + ); + Some(response.map_err(|_| { + warn!("Appservice returned invalid response bytes {}\n{}", destination, url); + Error::BadServerResponse("Server returned bad response.") + })) + } else { + None + } } diff --git a/src/api/client_server/account.rs b/src/api/client_server/account.rs index f77ce428..fa327f56 100644 --- a/src/api/client_server/account.rs +++ b/src/api/client_server/account.rs @@ -1,21 +1,21 @@ -use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; -use crate::{api::client_server, services, utils, Error, Result, Ruma}; +use register::RegistrationKind; use ruma::{ - api::client::{ - account::{ - change_password, deactivate, get_3pids, get_username_availability, register, - request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, - whoami, 
ThirdPartyIdRemovalStatus, - }, - error::ErrorKind, - uiaa::{AuthFlow, AuthType, UiaaInfo}, - }, - events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, - push, UserId, + api::client::{ + account::{ + change_password, deactivate, get_3pids, get_username_availability, register, + request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, whoami, + ThirdPartyIdRemovalStatus, + }, + error::ErrorKind, + uiaa::{AuthFlow, AuthType, UiaaInfo}, + }, + events::{room::message::RoomMessageEventContent, GlobalAccountDataEventType}, + push, UserId, }; use tracing::{info, warn}; -use register::RegistrationKind; +use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH}; +use crate::{api::client_server, services, utils, Error, Result, Ruma}; const RANDOM_USER_ID_LENGTH: usize = 10; @@ -28,303 +28,266 @@ const RANDOM_USER_ID_LENGTH: usize = 10; /// - The server name of the user id matches this server /// - No user or appservice on this server already claimed this username /// -/// Note: This will not reserve the username, so the username might become invalid when trying to register +/// Note: This will not reserve the username, so the username might become +/// invalid when trying to register pub async fn get_register_available_route( - body: Ruma, + body: Ruma, ) -> Result { - // Validate user id - let user_id = UserId::parse_with_server_name( - body.username.to_lowercase(), - services().globals.server_name(), - ) - .ok() - .filter(|user_id| { - !user_id.is_historical() && user_id.server_name() == services().globals.server_name() - }) - .ok_or(Error::BadRequest( - ErrorKind::InvalidUsername, - "Username is invalid.", - ))?; + // Validate user id + let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name()) + .ok() + .filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name()) + .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, 
"Username is invalid."))?; - // Check if username is creative enough - if services().users.exists(&user_id)? { - return Err(Error::BadRequest( - ErrorKind::UserInUse, - "Desired user ID is already taken.", - )); - } + // Check if username is creative enough + if services().users.exists(&user_id)? { + return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); + } - if services() - .globals - .forbidden_usernames() - .is_match(user_id.localpart()) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Username is forbidden.", - )); - } + if services().globals.forbidden_usernames().is_match(user_id.localpart()) { + return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden.")); + } - // TODO add check for appservice namespaces + // TODO add check for appservice namespaces - // If no if check is true we have an username that's available to be used. - Ok(get_username_availability::v3::Response { available: true }) + // If no if check is true we have an username that's available to be used. + Ok(get_username_availability::v3::Response { + available: true, + }) } /// # `POST /_matrix/client/v3/register` /// /// Register an account on this homeserver. /// -/// You can use [`GET /_matrix/client/v3/register/available`](fn.get_register_available_route.html) -/// to check if the user id is valid and available. +/// You can use [`GET +/// /_matrix/client/v3/register/available`](fn.get_register_available_route. +/// html) to check if the user id is valid and available. 
/// /// - Only works if registration is enabled -/// - If type is guest: ignores all parameters except initial_device_display_name +/// - If type is guest: ignores all parameters except +/// initial_device_display_name /// - If sender is not appservice: Requires UIAA (but we only use a dummy stage) -/// - If type is not guest and no username is given: Always fails after UIAA check +/// - If type is not guest and no username is given: Always fails after UIAA +/// check /// - Creates a new account and populates it with default account data -/// - If `inhibit_login` is false: Creates a device and returns device id and access_token +/// - If `inhibit_login` is false: Creates a device and returns device id and +/// access_token pub async fn register_route(body: Ruma) -> Result { - if !services().globals.allow_registration() && !body.from_appservice { - info!("Registration disabled and request not from known appservice, rejecting registration attempt for username {:?}", body.username); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Registration has been disabled.", - )); - } + if !services().globals.allow_registration() && !body.from_appservice { + info!( + "Registration disabled and request not from known appservice, rejecting registration attempt for username \ + {:?}", + body.username + ); + return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration has been disabled.")); + } - let is_guest = body.kind == RegistrationKind::Guest; + let is_guest = body.kind == RegistrationKind::Guest; - if is_guest - && (!services().globals.allow_guest_registration() - || (services().globals.allow_registration() - && services().globals.config.registration_token.is_some())) - { - info!("Guest registration disabled / registration enabled with token configured, rejecting guest registration, initial device name: {:?}", body.initial_device_display_name); - return Err(Error::BadRequest( - ErrorKind::GuestAccessForbidden, - "Guest registration is disabled.", - )); - } + if 
is_guest + && (!services().globals.allow_guest_registration() + || (services().globals.allow_registration() && services().globals.config.registration_token.is_some())) + { + info!( + "Guest registration disabled / registration enabled with token configured, rejecting guest registration, \ + initial device name: {:?}", + body.initial_device_display_name + ); + return Err(Error::BadRequest( + ErrorKind::GuestAccessForbidden, + "Guest registration is disabled.", + )); + } - // forbid guests from registering if there is not a real admin user yet. give generic user error. - if is_guest && services().users.count()? < 2 { - warn!("Guest account attempted to register before a real admin user has been registered, rejecting registration. Guest's initial device name: {:?}", body.initial_device_display_name); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Registration temporarily disabled.", - )); - } + // forbid guests from registering if there is not a real admin user yet. give + // generic user error. + if is_guest && services().users.count()? < 2 { + warn!( + "Guest account attempted to register before a real admin user has been registered, rejecting \ + registration. 
Guest's initial device name: {:?}", + body.initial_device_display_name + ); + return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration temporarily disabled.")); + } - let user_id = match (&body.username, is_guest) { - (Some(username), false) => { - let proposed_user_id = UserId::parse_with_server_name( - username.to_lowercase(), - services().globals.server_name(), - ) - .ok() - .filter(|user_id| { - !user_id.is_historical() - && user_id.server_name() == services().globals.server_name() - }) - .ok_or(Error::BadRequest( - ErrorKind::InvalidUsername, - "Username is invalid.", - ))?; + let user_id = match (&body.username, is_guest) { + (Some(username), false) => { + let proposed_user_id = + UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name()) + .ok() + .filter(|user_id| { + !user_id.is_historical() && user_id.server_name() == services().globals.server_name() + }) + .ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?; - if services().users.exists(&proposed_user_id)? { - return Err(Error::BadRequest( - ErrorKind::UserInUse, - "Desired user ID is already taken.", - )); - } + if services().users.exists(&proposed_user_id)? { + return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken.")); + } - if services() - .globals - .forbidden_usernames() - .is_match(proposed_user_id.localpart()) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Username is forbidden.", - )); - } + if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) { + return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden.")); + } - proposed_user_id - } - _ => loop { - let proposed_user_id = UserId::parse_with_server_name( - utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), - services().globals.server_name(), - ) - .unwrap(); - if !services().users.exists(&proposed_user_id)? 
{ - break proposed_user_id; - } - }, - }; + proposed_user_id + }, + _ => loop { + let proposed_user_id = UserId::parse_with_server_name( + utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(), + services().globals.server_name(), + ) + .unwrap(); + if !services().users.exists(&proposed_user_id)? { + break proposed_user_id; + } + }, + }; - // UIAA - let mut uiaainfo; - let skip_auth; - if services().globals.config.registration_token.is_some() { - // Registration token required - uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::RegistrationToken], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; - skip_auth = body.from_appservice; - } else { - // No registration token necessary, but clients must still go through the flow - uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Dummy], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; - skip_auth = body.from_appservice || is_guest; - } + // UIAA + let mut uiaainfo; + let skip_auth; + if services().globals.config.registration_token.is_some() { + // Registration token required + uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::RegistrationToken], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; + skip_auth = body.from_appservice; + } else { + // No registration token necessary, but clients must still go through the flow + uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::Dummy], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; + skip_auth = body.from_appservice || is_guest; + } - if !skip_auth { - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = services().uiaa.try_auth( - &UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid"), - "".into(), - auth, - &uiaainfo, - )?; - if 
!worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! - } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services().uiaa.create( - &UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid"), - "".into(), - &uiaainfo, - &json, - )?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } - } + if !skip_auth { + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth( + &UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), + "".into(), + auth, + &uiaainfo, + )?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create( + &UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"), + "".into(), + &uiaainfo, + &json, + )?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } + } - let password = if is_guest { - None - } else { - body.password.as_deref() - }; + let password = if is_guest { + None + } else { + body.password.as_deref() + }; - // Create user - services().users.create(&user_id, password)?; + // Create user + services().users.create(&user_id, password)?; - // Default to pretty displayname - let mut displayname = user_id.localpart().to_owned(); + // Default to pretty displayname + let mut displayname = user_id.localpart().to_owned(); - // If `new_user_displayname_suffix` is set, registration will push whatever content is set to the user's display name with a space before it - if !services().globals.new_user_displayname_suffix().is_empty() { - displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix())); - } + // If 
`new_user_displayname_suffix` is set, registration will push whatever + // content is set to the user's display name with a space before it + if !services().globals.new_user_displayname_suffix().is_empty() { + displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix())); + } - services() - .users - .set_displayname(&user_id, Some(displayname.clone())) - .await?; + services().users.set_displayname(&user_id, Some(displayname.clone())).await?; - // Initial account data - services().account_data.update( - None, - &user_id, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json always works"), - )?; + // Initial account data + services().account_data.update( + None, + &user_id, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json always works"), + )?; - // Inhibit login does not work for guests - if !is_guest && body.inhibit_login { - return Ok(register::v3::Response { - access_token: None, - user_id, - device_id: None, - refresh_token: None, - expires_in: None, - }); - } + // Inhibit login does not work for guests + if !is_guest && body.inhibit_login { + return Ok(register::v3::Response { + access_token: None, + user_id, + device_id: None, + refresh_token: None, + expires_in: None, + }); + } - // Generate new device id if the user didn't specify one - let device_id = if is_guest { - None - } else { - body.device_id.clone() - } - .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); + // Generate new device id if the user didn't specify one + let device_id = if is_guest { + None + } else { + 
body.device_id.clone() + } + .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); - // Generate new token for the device - let token = utils::random_string(TOKEN_LENGTH); + // Generate new token for the device + let token = utils::random_string(TOKEN_LENGTH); - // Create device for this account - services().users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - )?; + // Create device for this account + services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; - info!("New user \"{}\" registered on this server.", user_id); + info!("New user \"{}\" registered on this server.", user_id); - // log in conduit admin channel if a non-guest user registered - if !body.from_appservice && !is_guest { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "New user \"{user_id}\" registered on this server." - ))); - } + // log in conduit admin channel if a non-guest user registered + if !body.from_appservice && !is_guest { + services().admin.send_message(RoomMessageEventContent::notice_plain(format!( + "New user \"{user_id}\" registered on this server." 
+ ))); + } - // log in conduit admin channel if a guest registered - if !body.from_appservice && is_guest { - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "Guest user \"{user_id}\" with device display name `{:?}` registered on this server.", - body.initial_device_display_name - ))); - } + // log in conduit admin channel if a guest registered + if !body.from_appservice && is_guest { + services().admin.send_message(RoomMessageEventContent::notice_plain(format!( + "Guest user \"{user_id}\" with device display name `{:?}` registered on this server.", + body.initial_device_display_name + ))); + } - // If this is the first real user, grant them admin privileges except for guest users - // Note: the server user, @conduit:servername, is generated first - if services().users.count()? == 2 && !is_guest { - services() - .admin - .make_user_admin(&user_id, displayname) - .await?; + // If this is the first real user, grant them admin privileges except for guest + // users Note: the server user, @conduit:servername, is generated first + if services().users.count()? 
== 2 && !is_guest { + services().admin.make_user_admin(&user_id, displayname).await?; - warn!("Granting {} admin privileges as the first user", user_id); - } + warn!("Granting {} admin privileges as the first user", user_id); + } - Ok(register::v3::Response { - access_token: Some(token), - user_id, - device_id: Some(device_id), - refresh_token: None, - expires_in: None, - }) + Ok(register::v3::Response { + access_token: Some(token), + user_id, + device_id: Some(device_id), + refresh_token: None, + expires_in: None, + }) } /// # `POST /_matrix/client/r0/account/password` @@ -333,73 +296,65 @@ pub async fn register_route(body: Ruma) -> Result, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +pub async fn change_password_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let mut uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Password], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; + let mut uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::Password], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; - if !worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! 
- } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } - services() - .users - .set_password(sender_user, Some(&body.new_password))?; + services().users.set_password(sender_user, Some(&body.new_password))?; - if body.logout_devices { - // Logout all devices except the current one - for id in services() - .users - .all_device_ids(sender_user) - .filter_map(std::result::Result::ok) - .filter(|id| id != sender_device) - { - services().users.remove_device(sender_user, &id)?; - } - } + if body.logout_devices { + // Logout all devices except the current one + for id in services() + .users + .all_device_ids(sender_user) + .filter_map(std::result::Result::ok) + .filter(|id| id != sender_device) + { + services().users.remove_device(sender_user, &id)?; + } + } - info!("User {} changed their password.", sender_user); - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "User {sender_user} changed their password." - ))); + info!("User {} changed their password.", sender_user); + services().admin.send_message(RoomMessageEventContent::notice_plain(format!( + "User {sender_user} changed their password." 
+ ))); - Ok(change_password::v3::Response {}) + Ok(change_password::v3::Response {}) } /// # `GET _matrix/client/r0/account/whoami` @@ -408,14 +363,14 @@ pub async fn change_password_route( /// /// Note: Also works for Application Services pub async fn whoami_route(body: Ruma) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device_id = body.sender_device.clone(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let device_id = body.sender_device.clone(); - Ok(whoami::v3::Response { - user_id: sender_user.clone(), - device_id, - is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice, - }) + Ok(whoami::v3::Response { + user_id: sender_user.clone(), + device_id, + is_guest: services().users.is_deactivated(sender_user)? && !body.from_appservice, + }) } /// # `POST /_matrix/client/r0/account/deactivate` @@ -424,61 +379,53 @@ pub async fn whoami_route(body: Ruma) -> Result, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +pub async fn deactivate_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let mut uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Password], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; + let mut uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::Password], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; - if !worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! 
- } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } - // Make the user leave all rooms before deactivation - client_server::leave_all_rooms(sender_user).await?; + // Make the user leave all rooms before deactivation + client_server::leave_all_rooms(sender_user).await?; - // Remove devices and mark account as deactivated - services().users.deactivate_account(sender_user)?; + // Remove devices and mark account as deactivated + services().users.deactivate_account(sender_user)?; - info!("User {} deactivated their account.", sender_user); - services() - .admin - .send_message(RoomMessageEventContent::notice_plain(format!( - "User {sender_user} deactivated their account." - ))); + info!("User {} deactivated their account.", sender_user); + services().admin.send_message(RoomMessageEventContent::notice_plain(format!( + "User {sender_user} deactivated their account." 
+ ))); - Ok(deactivate::v3::Response { - id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, - }) + Ok(deactivate::v3::Response { + id_server_unbind_result: ThirdPartyIdRemovalStatus::NoSupport, + }) } /// # `GET _matrix/client/v3/account/3pid` @@ -486,38 +433,40 @@ pub async fn deactivate_route( /// Get a list of third party identifiers associated with this account. /// /// - Currently always returns empty list -pub async fn third_party_route( - body: Ruma, -) -> Result { - let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn third_party_route(body: Ruma) -> Result { + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(get_3pids::v3::Response::new(Vec::new())) + Ok(get_3pids::v3::Response::new(Vec::new())) } /// # `POST /_matrix/client/v3/account/3pid/email/requestToken` /// -/// "This API should be used to request validation tokens when adding an email address to an account" +/// "This API should be used to request validation tokens when adding an email +/// address to an account" /// -/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. +/// - 403 signals that The homeserver does not allow the third party identifier +/// as a contact option. pub async fn request_3pid_management_token_via_email_route( - _body: Ruma, + _body: Ruma, ) -> Result { - Err(Error::BadRequest( - ErrorKind::ThreepidDenied, - "Third party identifier is not allowed", - )) + Err(Error::BadRequest( + ErrorKind::ThreepidDenied, + "Third party identifier is not allowed", + )) } /// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken` /// -/// "This API should be used to request validation tokens when adding an phone number to an account" +/// "This API should be used to request validation tokens when adding an phone +/// number to an account" /// -/// - 403 signals that The homeserver does not allow the third party identifier as a contact option. 
+/// - 403 signals that The homeserver does not allow the third party identifier +/// as a contact option. pub async fn request_3pid_management_token_via_msisdn_route( - _body: Ruma, + _body: Ruma, ) -> Result { - Err(Error::BadRequest( - ErrorKind::ThreepidDenied, - "Third party identifier is not allowed", - )) + Err(Error::BadRequest( + ErrorKind::ThreepidDenied, + "Third party identifier is not allowed", + )) } diff --git a/src/api/client_server/alias.rs b/src/api/client_server/alias.rs index 230a23da..0e930fc3 100644 --- a/src/api/client_server/alias.rs +++ b/src/api/client_server/alias.rs @@ -1,64 +1,43 @@ -use crate::{services, Error, Result, Ruma}; use rand::seq::SliceRandom; use regex::Regex; use ruma::{ - api::{ - appservice, - client::{ - alias::{create_alias, delete_alias, get_alias}, - error::ErrorKind, - }, - federation, - }, - OwnedRoomAliasId, OwnedServerName, + api::{ + appservice, + client::{ + alias::{create_alias, delete_alias, get_alias}, + error::ErrorKind, + }, + federation, + }, + OwnedRoomAliasId, OwnedServerName, }; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/v3/directory/room/{roomAlias}` /// /// Creates a new room alias on this server. 
-pub async fn create_alias_route( - body: Ruma, -) -> Result { - if body.room_alias.server_name() != services().globals.server_name() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Alias is from another server.", - )); - } +pub async fn create_alias_route(body: Ruma) -> Result { + if body.room_alias.server_name() != services().globals.server_name() { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); + } - if services() - .globals - .forbidden_room_names() - .is_match(body.room_alias.alias()) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Room alias is forbidden.", - )); - } + if services().globals.forbidden_room_names().is_match(body.room_alias.alias()) { + return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden.")); + } - if services() - .rooms - .alias - .resolve_local_alias(&body.room_alias)? - .is_some() - { - return Err(Error::Conflict("Alias already exists.")); - } + if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() { + return Err(Error::Conflict("Alias already exists.")); + } - if services() - .rooms - .alias - .set_alias(&body.room_alias, &body.room_id) - .is_err() - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid room alias. Alias must be in the form of '#localpart:server_name'", - )); - }; + if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid room alias. 
Alias must be in the form of '#localpart:server_name'", + )); + }; - Ok(create_alias::v3::Response::new()) + Ok(create_alias::v3::Response::new()) } /// # `DELETE /_matrix/client/v3/directory/room/{roomAlias}` @@ -67,183 +46,137 @@ pub async fn create_alias_route( /// /// - TODO: additional access control checks /// - TODO: Update canonical alias event -pub async fn delete_alias_route( - body: Ruma, -) -> Result { - if body.room_alias.server_name() != services().globals.server_name() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Alias is from another server.", - )); - } +pub async fn delete_alias_route(body: Ruma) -> Result { + if body.room_alias.server_name() != services().globals.server_name() { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server.")); + } - if services() - .rooms - .alias - .resolve_local_alias(&body.room_alias)? - .is_none() - { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Alias does not exist.", - )); - } + if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); + } - if services() - .rooms - .alias - .remove_alias(&body.room_alias) - .is_err() - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid room alias. Alias must be in the form of '#localpart:server_name'", - )); - }; + if services().rooms.alias.remove_alias(&body.room_alias).is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid room alias. Alias must be in the form of '#localpart:server_name'", + )); + }; - // TODO: update alt_aliases? + // TODO: update alt_aliases? - Ok(delete_alias::v3::Response::new()) + Ok(delete_alias::v3::Response::new()) } /// # `GET /_matrix/client/v3/directory/room/{roomAlias}` /// /// Resolve an alias locally or over federation. 
-pub async fn get_alias_route( - body: Ruma, -) -> Result { - get_alias_helper(body.body.room_alias).await +pub async fn get_alias_route(body: Ruma) -> Result { + get_alias_helper(body.body.room_alias).await } -pub(crate) async fn get_alias_helper( - room_alias: OwnedRoomAliasId, -) -> Result { - if room_alias.server_name() != services().globals.server_name() { - let response = services() - .sending - .send_federation_request( - room_alias.server_name(), - federation::query::get_room_information::v1::Request { - room_alias: room_alias.clone(), - }, - ) - .await?; +pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result { + if room_alias.server_name() != services().globals.server_name() { + let response = services() + .sending + .send_federation_request( + room_alias.server_name(), + federation::query::get_room_information::v1::Request { + room_alias: room_alias.clone(), + }, + ) + .await?; - let room_id = response.room_id; + let room_id = response.room_id; - let mut servers = response.servers; + let mut servers = response.servers; - // find active servers in room state cache to suggest - for extra_servers in services() - .rooms - .state_cache - .room_servers(&room_id) - .filter_map(std::result::Result::ok) - { - servers.push(extra_servers); - } + // find active servers in room state cache to suggest + for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) { + servers.push(extra_servers); + } - // insert our server as the very first choice if in list - if let Some(server_index) = servers - .clone() - .into_iter() - .position(|server| server == services().globals.server_name()) - { - servers.remove(server_index); - servers.insert(0, services().globals.server_name().to_owned()); - } + // insert our server as the very first choice if in list + if let Some(server_index) = + servers.clone().into_iter().position(|server| server == services().globals.server_name()) + { + 
servers.remove(server_index); + servers.insert(0, services().globals.server_name().to_owned()); + } - servers.sort_unstable(); - servers.dedup(); + servers.sort_unstable(); + servers.dedup(); - // shuffle list of servers randomly after sort and dedupe - servers.shuffle(&mut rand::thread_rng()); + // shuffle list of servers randomly after sort and dedupe + servers.shuffle(&mut rand::thread_rng()); - return Ok(get_alias::v3::Response::new(room_id, servers)); - } + return Ok(get_alias::v3::Response::new(room_id, servers)); + } - let mut room_id = None; - match services().rooms.alias.resolve_local_alias(&room_alias)? { - Some(r) => room_id = Some(r), - None => { - for (_id, registration) in services().appservice.all()? { - let aliases = registration - .namespaces - .aliases - .iter() - .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) - .collect::>(); + let mut room_id = None; + match services().rooms.alias.resolve_local_alias(&room_alias)? { + Some(r) => room_id = Some(r), + None => { + for (_id, registration) in services().appservice.all()? { + let aliases = registration + .namespaces + .aliases + .iter() + .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) + .collect::>(); - if aliases - .iter() - .any(|aliases| aliases.is_match(room_alias.as_str())) - && if let Some(opt_result) = services() - .sending - .send_appservice_request( - registration, - appservice::query::query_room_alias::v1::Request { - room_alias: room_alias.clone(), - }, - ) - .await - { - opt_result.is_ok() - } else { - false - } - { - room_id = Some( - services() - .rooms - .alias - .resolve_local_alias(&room_alias)? - .ok_or_else(|| { - Error::bad_config("Appservice lied to us. 
Room does not exist.") - })?, - ); - break; - } - } - } - }; + if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str())) + && if let Some(opt_result) = services() + .sending + .send_appservice_request( + registration, + appservice::query::query_room_alias::v1::Request { + room_alias: room_alias.clone(), + }, + ) + .await + { + opt_result.is_ok() + } else { + false + } { + room_id = Some( + services() + .rooms + .alias + .resolve_local_alias(&room_alias)? + .ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?, + ); + break; + } + } + }, + }; - let room_id = match room_id { - Some(room_id) => room_id, - None => { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room with alias not found.", - )) - } - }; + let room_id = match room_id { + Some(room_id) => room_id, + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")), + }; - let mut servers: Vec = Vec::new(); + let mut servers: Vec = Vec::new(); - // find active servers in room state cache to suggest - for extra_servers in services() - .rooms - .state_cache - .room_servers(&room_id) - .filter_map(std::result::Result::ok) - { - servers.push(extra_servers); - } + // find active servers in room state cache to suggest + for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) { + servers.push(extra_servers); + } - // insert our server as the very first choice if in list - if let Some(server_index) = servers - .clone() - .into_iter() - .position(|server| server == services().globals.server_name()) - { - servers.remove(server_index); - servers.insert(0, services().globals.server_name().to_owned()); - } + // insert our server as the very first choice if in list + if let Some(server_index) = + servers.clone().into_iter().position(|server| server == services().globals.server_name()) + { + servers.remove(server_index); + servers.insert(0, services().globals.server_name().to_owned()); + } - 
servers.sort_unstable(); - servers.dedup(); + servers.sort_unstable(); + servers.dedup(); - // shuffle list of servers randomly after sort and dedupe - servers.shuffle(&mut rand::thread_rng()); + // shuffle list of servers randomly after sort and dedupe + servers.shuffle(&mut rand::thread_rng()); - Ok(get_alias::v3::Response::new(room_id, servers)) + Ok(get_alias::v3::Response::new(room_id, servers)) } diff --git a/src/api/client_server/backup.rs b/src/api/client_server/backup.rs index 8bbe3ef1..3e35da4f 100644 --- a/src/api/client_server/backup.rs +++ b/src/api/client_server/backup.rs @@ -1,362 +1,275 @@ -use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ - backup::{ - add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, - create_backup_version, delete_backup_keys, delete_backup_keys_for_room, - delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys, - get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info, - update_backup_version, - }, - error::ErrorKind, + backup::{ + add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version, + delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version, + get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session, + get_latest_backup_info, update_backup_version, + }, + error::ErrorKind, }; +use crate::{services, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/room_keys/version` /// /// Creates a new backup. 
pub async fn create_backup_version_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let version = services() - .key_backups - .create_backup(sender_user, &body.algorithm)?; + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let version = services().key_backups.create_backup(sender_user, &body.algorithm)?; - Ok(create_backup_version::v3::Response { version }) + Ok(create_backup_version::v3::Response { + version, + }) } /// # `PUT /_matrix/client/r0/room_keys/version/{version}` /// -/// Update information about an existing backup. Only `auth_data` can be modified. +/// Update information about an existing backup. Only `auth_data` can be +/// modified. pub async fn update_backup_version_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .update_backup(sender_user, &body.version, &body.algorithm)?; + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?; - Ok(update_backup_version::v3::Response {}) + Ok(update_backup_version::v3::Response {}) } /// # `GET /_matrix/client/r0/room_keys/version` /// /// Get information about the latest backup version. pub async fn get_latest_backup_info_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let (version, algorithm) = services() - .key_backups - .get_latest_backup(sender_user)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Key backup does not exist.", - ))?; + let (version, algorithm) = services() + .key_backups + .get_latest_backup(sender_user)? 
+ .ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; - Ok(get_latest_backup_info::v3::Response { - algorithm, - count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(), - etag: services().key_backups.get_etag(sender_user, &version)?, - version, - }) + Ok(get_latest_backup_info::v3::Response { + algorithm, + count: (services().key_backups.count_keys(sender_user, &version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &version)?, + version, + }) } /// # `GET /_matrix/client/r0/room_keys/version` /// /// Get information about an existing backup. -pub async fn get_backup_info_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let algorithm = services() - .key_backups - .get_backup(sender_user, &body.version)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Key backup does not exist.", - ))?; +pub async fn get_backup_info_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let algorithm = services() + .key_backups + .get_backup(sender_user, &body.version)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?; - Ok(get_backup_info::v3::Response { - algorithm, - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - version: body.version.clone(), - }) + Ok(get_backup_info::v3::Response { + algorithm, + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + version: body.version.clone(), + }) } /// # `DELETE /_matrix/client/r0/room_keys/version/{version}` /// /// Delete an existing key backup. 
/// -/// - Deletes both information about the backup, as well as all key data related to the backup +/// - Deletes both information about the backup, as well as all key data related +/// to the backup pub async fn delete_backup_version_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_backup(sender_user, &body.version)?; + services().key_backups.delete_backup(sender_user, &body.version)?; - Ok(delete_backup_version::v3::Response {}) + Ok(delete_backup_version::v3::Response {}) } /// # `PUT /_matrix/client/r0/room_keys/keys` /// /// Add the received backup keys to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag -pub async fn add_backup_keys_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn add_backup_keys_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services() - .key_backups - .get_latest_backup_version(sender_user)? 
- .as_ref() - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); - } + if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "You may only manipulate the most recently created version of the backup.", + )); + } - for (room_id, room) in &body.rooms { - for (session_id, key_data) in &room.sessions { - services().key_backups.add_key( - sender_user, - &body.version, - room_id, - session_id, - key_data, - )?; - } - } + for (room_id, room) in &body.rooms { + for (session_id, key_data) in &room.sessions { + services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?; + } + } - Ok(add_backup_keys::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(add_backup_keys::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}` /// /// Add the received backup keys to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_room_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services() - .key_backups - .get_latest_backup_version(sender_user)? 
- .as_ref() - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); - } + if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "You may only manipulate the most recently created version of the backup.", + )); + } - for (session_id, key_data) in &body.sessions { - services().key_backups.add_key( - sender_user, - &body.version, - &body.room_id, - session_id, - key_data, - )?; - } + for (session_id, key_data) in &body.sessions { + services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?; + } - Ok(add_backup_keys_for_room::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(add_backup_keys_for_room::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } /// # `PUT /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// /// Add the received backup key to the database. /// -/// - Only manipulating the most recently created version of the backup is allowed +/// - Only manipulating the most recently created version of the backup is +/// allowed /// - Adds the keys to the backup /// - Returns the new number of keys in this backup and the etag pub async fn add_backup_keys_for_session_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if Some(&body.version) - != services() - .key_backups - .get_latest_backup_version(sender_user)? 
- .as_ref() - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "You may only manipulate the most recently created version of the backup.", - )); - } + if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "You may only manipulate the most recently created version of the backup.", + )); + } - services().key_backups.add_key( - sender_user, - &body.version, - &body.room_id, - &body.session_id, - &body.session_data, - )?; + services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?; - Ok(add_backup_keys_for_session::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(add_backup_keys_for_session::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } /// # `GET /_matrix/client/r0/room_keys/keys` /// /// Retrieves all keys from the backup. -pub async fn get_backup_keys_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_backup_keys_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let rooms = services().key_backups.get_all(sender_user, &body.version)?; + let rooms = services().key_backups.get_all(sender_user, &body.version)?; - Ok(get_backup_keys::v3::Response { rooms }) + Ok(get_backup_keys::v3::Response { + rooms, + }) } /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}` /// /// Retrieves all keys from the backup for a given room. 
pub async fn get_backup_keys_for_room_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sessions = services() - .key_backups - .get_room(sender_user, &body.version, &body.room_id)?; + let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?; - Ok(get_backup_keys_for_room::v3::Response { sessions }) + Ok(get_backup_keys_for_room::v3::Response { + sessions, + }) } /// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// /// Retrieves a key from the backup. pub async fn get_backup_keys_for_session_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let key_data = services() - .key_backups - .get_session(sender_user, &body.version, &body.room_id, &body.session_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Backup key not found for this user's session.", - ))?; + let key_data = + services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or( + Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."), + )?; - Ok(get_backup_keys_for_session::v3::Response { key_data }) + Ok(get_backup_keys_for_session::v3::Response { + key_data, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys` /// /// Delete the keys from the backup. 
pub async fn delete_backup_keys_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_all_keys(sender_user, &body.version)?; + services().key_backups.delete_all_keys(sender_user, &body.version)?; - Ok(delete_backup_keys::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(delete_backup_keys::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}` /// /// Delete the keys from the backup for a given room. pub async fn delete_backup_keys_for_room_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .key_backups - .delete_room_keys(sender_user, &body.version, &body.room_id)?; + services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?; - Ok(delete_backup_keys_for_room::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(delete_backup_keys_for_room::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } /// # `DELETE /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}` /// /// Delete a key from the backup. 
pub async fn delete_backup_keys_for_session_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services().key_backups.delete_room_key( - sender_user, - &body.version, - &body.room_id, - &body.session_id, - )?; + services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?; - Ok(delete_backup_keys_for_session::v3::Response { - count: (services() - .key_backups - .count_keys(sender_user, &body.version)? as u32) - .into(), - etag: services() - .key_backups - .get_etag(sender_user, &body.version)?, - }) + Ok(delete_backup_keys_for_session::v3::Response { + count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(), + etag: services().key_backups.get_etag(sender_user, &body.version)?, + }) } diff --git a/src/api/client_server/capabilities.rs b/src/api/client_server/capabilities.rs index 233e3c9c..f22ba63c 100644 --- a/src/api/client_server/capabilities.rs +++ b/src/api/client_server/capabilities.rs @@ -1,28 +1,33 @@ -use crate::{services, Result, Ruma}; -use ruma::api::client::discovery::get_capabilities::{ - self, Capabilities, RoomVersionStability, RoomVersionsCapability, -}; use std::collections::BTreeMap; +use ruma::api::client::discovery::get_capabilities::{ + self, Capabilities, RoomVersionStability, RoomVersionsCapability, +}; + +use crate::{services, Result, Ruma}; + /// # `GET /_matrix/client/r0/capabilities` /// -/// Get information on the supported feature set and other relevent capabilities of this server. +/// Get information on the supported feature set and other relevent capabilities +/// of this server. 
pub async fn get_capabilities_route( - _body: Ruma, + _body: Ruma, ) -> Result { - let mut available = BTreeMap::new(); - for room_version in &services().globals.unstable_room_versions { - available.insert(room_version.clone(), RoomVersionStability::Unstable); - } - for room_version in &services().globals.stable_room_versions { - available.insert(room_version.clone(), RoomVersionStability::Stable); - } + let mut available = BTreeMap::new(); + for room_version in &services().globals.unstable_room_versions { + available.insert(room_version.clone(), RoomVersionStability::Unstable); + } + for room_version in &services().globals.stable_room_versions { + available.insert(room_version.clone(), RoomVersionStability::Stable); + } - let mut capabilities = Capabilities::new(); - capabilities.room_versions = RoomVersionsCapability { - default: services().globals.default_room_version(), - available, - }; + let mut capabilities = Capabilities::new(); + capabilities.room_versions = RoomVersionsCapability { + default: services().globals.default_room_version(), + available, + }; - Ok(get_capabilities::v3::Response { capabilities }) + Ok(get_capabilities::v3::Response { + capabilities, + }) } diff --git a/src/api/client_server/config.rs b/src/api/client_server/config.rs index 37279e35..247b4ef8 100644 --- a/src/api/client_server/config.rs +++ b/src/api/client_server/config.rs @@ -1,116 +1,118 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ - api::client::{ - config::{ - get_global_account_data, get_room_account_data, set_global_account_data, - set_room_account_data, - }, - error::ErrorKind, - }, - events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent}, - serde::Raw, + api::client::{ + config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data}, + error::ErrorKind, + }, + events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent}, + serde::Raw, }; use serde::Deserialize; use 
serde_json::{json, value::RawValue as RawJsonValue}; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}` /// /// Sets some account data for the sender user. pub async fn set_global_account_data_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; + let data: serde_json::Value = serde_json::from_str(body.data.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - let event_type = body.event_type.to_string(); + let event_type = body.event_type.to_string(); - services().account_data.update( - None, - sender_user, - event_type.clone().into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services().account_data.update( + None, + sender_user, + event_type.clone().into(), + &json!({ + "type": event_type, + "content": data, + }), + )?; - Ok(set_global_account_data::v3::Response {}) + Ok(set_global_account_data::v3::Response {}) } /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` /// /// Sets some room account data for the sender user. 
pub async fn set_room_account_data_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let data: serde_json::Value = serde_json::from_str(body.data.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; + let data: serde_json::Value = serde_json::from_str(body.data.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Data is invalid."))?; - let event_type = body.event_type.to_string(); + let event_type = body.event_type.to_string(); - services().account_data.update( - Some(&body.room_id), - sender_user, - event_type.clone().into(), - &json!({ - "type": event_type, - "content": data, - }), - )?; + services().account_data.update( + Some(&body.room_id), + sender_user, + event_type.clone().into(), + &json!({ + "type": event_type, + "content": data, + }), + )?; - Ok(set_room_account_data::v3::Response {}) + Ok(set_room_account_data::v3::Response {}) } /// # `GET /_matrix/client/r0/user/{userId}/account_data/{type}` /// /// Gets some account data for the sender user. pub async fn get_global_account_data_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services() - .account_data - .get(None, sender_user, body.event_type.to_string().into())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + let event: Box = services() + .account_data + .get(None, sender_user, body.event_type.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
- .content; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; - Ok(get_global_account_data::v3::Response { account_data }) + Ok(get_global_account_data::v3::Response { + account_data, + }) } /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}` /// /// Gets some room account data for the sender user. pub async fn get_room_account_data_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event: Box = services() - .account_data - .get(Some(&body.room_id), sender_user, body.event_type.clone())? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; + let event: Box = services() + .account_data + .get(Some(&body.room_id), sender_user, body.event_type.clone())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Data not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? 
+ .content; - Ok(get_room_account_data::v3::Response { account_data }) + Ok(get_room_account_data::v3::Response { + account_data, + }) } #[derive(Deserialize)] struct ExtractRoomEventContent { - content: Raw, + content: Raw, } #[derive(Deserialize)] struct ExtractGlobalEventContent { - content: Raw, + content: Raw, } diff --git a/src/api/client_server/context.rs b/src/api/client_server/context.rs index e828795f..7f7fc97f 100644 --- a/src/api/client_server/context.rs +++ b/src/api/client_server/context.rs @@ -1,209 +1,177 @@ -use crate::{services, Error, Result, Ruma}; -use ruma::{ - api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, - events::StateEventType, -}; use std::collections::HashSet; + +use ruma::{ + api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions}, + events::StateEventType, +}; use tracing::error; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/rooms/{roomId}/context` /// /// Allows loading room history around an event. 
/// -/// - Only works if the user is joined (TODO: always allow, but only show events if the user was +/// - Only works if the user is joined (TODO: always allow, but only show events +/// if the user was /// joined, depending on history_visibility) -pub async fn get_context_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +pub async fn get_context_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members, - } => (true, *include_redundant_members), - LazyLoadOptions::Disabled => (false, false), - }; + let (lazy_load_enabled, lazy_load_send_redundant) = match &body.filter.lazy_load_options { + LazyLoadOptions::Enabled { + include_redundant_members, + } => (true, *include_redundant_members), + LazyLoadOptions::Disabled => (false, false), + }; - let mut lazy_loaded = HashSet::new(); + let mut lazy_loaded = HashSet::new(); - let base_token = services() - .rooms - .timeline - .get_pdu_count(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Base event id not found.", - ))?; + let base_token = services() + .rooms + .timeline + .get_pdu_count(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?; - let base_event = - services() - .rooms - .timeline - .get_pdu(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Base event not found.", - ))?; + let base_event = services() + .rooms + .timeline + .get_pdu(&body.event_id)? 
+ .ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?; - let room_id = base_event.room_id.clone(); + let room_id = base_event.room_id.clone(); - if !services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &body.event_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this event.", - )); - } + if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this event.", + )); + } - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &base_event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(base_event.sender.as_str().to_owned()); - } + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + &room_id, + &base_event.sender, + )? || lazy_load_send_redundant + { + lazy_loaded.insert(base_event.sender.as_str().to_owned()); + } - // Use limit with maximum 100 - let limit = u64::from(body.limit).min(100) as usize; + // Use limit with maximum 100 + let limit = u64::from(body.limit).min(100) as usize; - let base_event = base_event.to_room_event(); + let base_event = base_event.to_room_event(); - let events_before: Vec<_> = services() - .rooms - .timeline - .pdus_until(sender_user, &room_id, base_token)? - .take(limit / 2) - .filter_map(std::result::Result::ok) // Remove buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) - }) - .collect(); + let events_before: Vec<_> = services() + .rooms + .timeline + .pdus_until(sender_user, &room_id, base_token)? 
+ .take(limit / 2) + .filter_map(std::result::Result::ok) // Remove buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &room_id, &pdu.event_id) + .unwrap_or(false) + }) + .collect(); - for (_, event) in &events_before { - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } + for (_, event) in &events_before { + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + &room_id, + &event.sender, + )? || lazy_load_send_redundant + { + lazy_loaded.insert(event.sender.as_str().to_owned()); + } + } - let start_token = events_before - .last() - .map(|(count, _)| count.stringify()) - .unwrap_or_else(|| base_token.stringify()); + let start_token = + events_before.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify()); - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); - let events_after: Vec<_> = services() - .rooms - .timeline - .pdus_after(sender_user, &room_id, base_token)? - .take(limit / 2) - .filter_map(std::result::Result::ok) // Remove buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &room_id, &pdu.event_id) - .unwrap_or(false) - }) - .collect(); + let events_after: Vec<_> = services() + .rooms + .timeline + .pdus_after(sender_user, &room_id, base_token)? 
+ .take(limit / 2) + .filter_map(std::result::Result::ok) // Remove buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &room_id, &pdu.event_id) + .unwrap_or(false) + }) + .collect(); - for (_, event) in &events_after { - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &room_id, - &event.sender, - )? || lazy_load_send_redundant - { - lazy_loaded.insert(event.sender.as_str().to_owned()); - } - } + for (_, event) in &events_after { + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + &room_id, + &event.sender, + )? || lazy_load_send_redundant + { + lazy_loaded.insert(event.sender.as_str().to_owned()); + } + } - let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash( - events_after - .last() - .map_or(&*body.event_id, |(_, e)| &*e.event_id), - )? { - Some(s) => s, - None => services() - .rooms - .state - .get_room_shortstatehash(&room_id)? - .expect("All rooms have state"), - }; + let shortstatehash = match services() + .rooms + .state_accessor + .pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))? 
+ { + Some(s) => s, + None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"), + }; - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; + let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; - let end_token = events_after - .last() - .map(|(count, _)| count.stringify()) - .unwrap_or_else(|| base_token.stringify()); + let end_token = events_after.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify()); - let events_after: Vec<_> = events_after - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); - let mut state = Vec::new(); + let mut state = Vec::new(); - for (shortstatekey, id) in state_ids { - let (event_type, state_key) = services() - .rooms - .short - .get_statekey_from_short(shortstatekey)?; + for (shortstatekey, id) in state_ids { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; - if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - state.push(pdu.to_state_event()); - } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - state.push(pdu.to_state_event()); - } - } + if event_type != StateEventType::RoomMember { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + state.push(pdu.to_state_event()); + } else if !lazy_load_enabled || lazy_loaded.contains(&state_key) { + let pdu = match services().rooms.timeline.get_pdu(&id)? 
{ + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + state.push(pdu.to_state_event()); + } + } - let resp = get_context::v3::Response { - start: Some(start_token), - end: Some(end_token), - events_before, - event: Some(base_event), - events_after, - state, - }; + let resp = get_context::v3::Response { + start: Some(start_token), + end: Some(end_token), + events_before, + event: Some(base_event), + events_after, + state, + }; - Ok(resp) + Ok(resp) } diff --git a/src/api/client_server/device.rs b/src/api/client_server/device.rs index df9b0c3f..d0474c38 100644 --- a/src/api/client_server/device.rs +++ b/src/api/client_server/device.rs @@ -1,65 +1,61 @@ -use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{ - device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, - error::ErrorKind, - uiaa::{AuthFlow, AuthType, UiaaInfo}, + device::{self, delete_device, delete_devices, get_device, get_devices, update_device}, + error::ErrorKind, + uiaa::{AuthFlow, AuthType, UiaaInfo}, }; use super::SESSION_ID_LENGTH; +use crate::{services, utils, Error, Result, Ruma}; /// # `GET /_matrix/client/r0/devices` /// /// Get metadata on all devices of the sender user. 
-pub async fn get_devices_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_devices_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let devices: Vec = services() - .users - .all_devices_metadata(sender_user) - .filter_map(std::result::Result::ok) // Filter out buggy devices - .collect(); + let devices: Vec = services() + .users + .all_devices_metadata(sender_user) + .filter_map(std::result::Result::ok) // Filter out buggy devices + .collect(); - Ok(get_devices::v3::Response { devices }) + Ok(get_devices::v3::Response { + devices, + }) } /// # `GET /_matrix/client/r0/devices/{deviceId}` /// /// Get metadata on a single device of the sender user. -pub async fn get_device_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_device_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let device = services() - .users - .get_device_metadata(sender_user, &body.body.device_id)? - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + let device = services() + .users + .get_device_metadata(sender_user, &body.body.device_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; - Ok(get_device::v3::Response { device }) + Ok(get_device::v3::Response { + device, + }) } /// # `PUT /_matrix/client/r0/devices/{deviceId}` /// /// Updates the metadata on a given device of the sender user. -pub async fn update_device_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn update_device_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device = services() - .users - .get_device_metadata(sender_user, &body.device_id)? 
- .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; + let mut device = services() + .users + .get_device_metadata(sender_user, &body.device_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?; - device.display_name = body.display_name.clone(); + device.display_name = body.display_name.clone(); - services() - .users - .update_device_metadata(sender_user, &body.device_id, &device)?; + services().users.update_device_metadata(sender_user, &body.device_id, &device)?; - Ok(update_device::v3::Response {}) + Ok(update_device::v3::Response {}) } /// # `DELETE /_matrix/client/r0/devices/{deviceId}` @@ -68,50 +64,42 @@ pub async fn update_device_route( /// /// - Requires UIAA to verify user password /// - Invalidates access token -/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) +/// - Deletes device metadata (device id, device display name, last seen ip, +/// last seen ts) /// - Forgets to-device events /// - Triggers device list updates -pub async fn delete_device_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +pub async fn delete_device_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - // UIAA - let mut uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Password], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; + // UIAA + let mut uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::Password], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, 
sender_device, auth, &uiaainfo)?; - if !worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! - } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } - services() - .users - .remove_device(sender_user, &body.device_id)?; + services().users.remove_device(sender_user, &body.device_id)?; - Ok(delete_device::v3::Response {}) + Ok(delete_device::v3::Response {}) } /// # `PUT /_matrix/client/r0/devices/{deviceId}` @@ -122,48 +110,42 @@ pub async fn delete_device_route( /// /// For each device: /// - Invalidates access token -/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts) +/// - Deletes device metadata (device id, device display name, last seen ip, +/// last seen ts) /// - Forgets to-device events /// - Triggers device list updates -pub async fn delete_devices_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +pub async fn delete_devices_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is 
authenticated"); - // UIAA - let mut uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Password], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; + // UIAA + let mut uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + stages: vec![AuthType::Password], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; - if !worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! - } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! 
+ } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } - for device_id in &body.devices { - services().users.remove_device(sender_user, device_id)?; - } + for device_id in &body.devices { + services().users.remove_device(sender_user, device_id)?; + } - Ok(delete_devices::v3::Response {}) + Ok(delete_devices::v3::Response {}) } diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index f0b0e25f..d328e0a6 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -1,57 +1,51 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ - api::{ - client::{ - directory::{ - get_public_rooms, get_public_rooms_filtered, get_room_visibility, - set_room_visibility, - }, - error::ErrorKind, - room, - }, - federation, - }, - directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork}, - events::{ - room::{ - avatar::RoomAvatarEventContent, - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - topic::RoomTopicEventContent, - }, - StateEventType, - }, - ServerName, UInt, + api::{ + client::{ + directory::{get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility}, + error::ErrorKind, + room, + }, + federation, + }, + directory::{Filter, PublicRoomJoinRule, PublicRoomsChunk, RoomNetwork}, + events::{ + room::{ + avatar::RoomAvatarEventContent, + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + 
history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + topic::RoomTopicEventContent, + }, + StateEventType, + }, + ServerName, UInt, }; use tracing::{error, info, warn}; +use crate::{services, Error, Result, Ruma}; + /// # `POST /_matrix/client/v3/publicRooms` /// /// Lists the public rooms on this server. /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_filtered_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services() - .globals - .config - .allow_public_room_directory_without_auth - { - let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); - } + if !services().globals.config.allow_public_room_directory_without_auth { + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); + } - get_public_rooms_filtered_helper( - body.server.as_deref(), - body.limit, - body.since.as_deref(), - &body.filter, - &body.room_network, - ) - .await + get_public_rooms_filtered_helper( + body.server.as_deref(), + body.limit, + body.since.as_deref(), + &body.filter, + &body.room_network, + ) + .await } /// # `GET /_matrix/client/v3/publicRooms` @@ -60,31 +54,27 @@ pub async fn get_public_rooms_filtered_route( /// /// - Rooms are ordered by the number of joined members pub async fn get_public_rooms_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services() - .globals - .config - .allow_public_room_directory_without_auth - { - let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); - } + if !services().globals.config.allow_public_room_directory_without_auth { + let _sender_user = body.sender_user.as_ref().expect("user is authenticated"); + } - let response = get_public_rooms_filtered_helper( - body.server.as_deref(), - body.limit, - body.since.as_deref(), - &Filter::default(), - &RoomNetwork::Matrix, - ) - .await?; + let response = get_public_rooms_filtered_helper( + body.server.as_deref(), + 
body.limit, + body.since.as_deref(), + &Filter::default(), + &RoomNetwork::Matrix, + ) + .await?; - Ok(get_public_rooms::v3::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }) + Ok(get_public_rooms::v3::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }) } /// # `PUT /_matrix/client/r0/directory/list/room/{roomId}` @@ -93,294 +83,261 @@ pub async fn get_public_rooms_route( /// /// - TODO: Access control checks pub async fn set_room_visibility_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().rooms.metadata.exists(&body.room_id)? { - // Return 404 if the room doesn't exist - return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); - } + if !services().rooms.metadata.exists(&body.room_id)? 
{ + // Return 404 if the room doesn't exist + return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); + } - match &body.visibility { - room::Visibility::Public => { - services().rooms.directory.set_public(&body.room_id)?; - info!("{} made {} public", sender_user, body.room_id); - } - room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, - _ => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room visibility type is not supported.", - )); - } - } + match &body.visibility { + room::Visibility::Public => { + services().rooms.directory.set_public(&body.room_id)?; + info!("{} made {} public", sender_user, body.room_id); + }, + room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?, + _ => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room visibility type is not supported.", + )); + }, + } - Ok(set_room_visibility::v3::Response {}) + Ok(set_room_visibility::v3::Response {}) } /// # `GET /_matrix/client/r0/directory/list/room/{roomId}` /// /// Gets the visibility of a given room in the room directory. pub async fn get_room_visibility_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().rooms.metadata.exists(&body.room_id)? { - // Return 404 if the room doesn't exist - return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); - } + if !services().rooms.metadata.exists(&body.room_id)? { + // Return 404 if the room doesn't exist + return Err(Error::BadRequest(ErrorKind::NotFound, "Room not found")); + } - Ok(get_room_visibility::v3::Response { - visibility: if services().rooms.directory.is_public_room(&body.room_id)? { - room::Visibility::Public - } else { - room::Visibility::Private - }, - }) + Ok(get_room_visibility::v3::Response { + visibility: if services().rooms.directory.is_public_room(&body.room_id)? 
{ + room::Visibility::Public + } else { + room::Visibility::Private + }, + }) } pub(crate) async fn get_public_rooms_filtered_helper( - server: Option<&ServerName>, - limit: Option, - since: Option<&str>, - filter: &Filter, - _network: &RoomNetwork, + server: Option<&ServerName>, limit: Option, since: Option<&str>, filter: &Filter, _network: &RoomNetwork, ) -> Result { - if let Some(other_server) = - server.filter(|server| *server != services().globals.server_name().as_str()) - { - let response = services() - .sending - .send_federation_request( - other_server, - federation::directory::get_public_rooms_filtered::v1::Request { - limit, - since: since.map(ToOwned::to_owned), - filter: Filter { - generic_search_term: filter.generic_search_term.clone(), - room_types: filter.room_types.clone(), - }, - room_network: RoomNetwork::Matrix, - }, - ) - .await?; + if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) { + let response = services() + .sending + .send_federation_request( + other_server, + federation::directory::get_public_rooms_filtered::v1::Request { + limit, + since: since.map(ToOwned::to_owned), + filter: Filter { + generic_search_term: filter.generic_search_term.clone(), + room_types: filter.room_types.clone(), + }, + room_network: RoomNetwork::Matrix, + }, + ) + .await?; - return Ok(get_public_rooms_filtered::v3::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }); - } + return Ok(get_public_rooms_filtered::v3::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }); + } - let limit = limit.map_or(10, u64::from); - let mut num_since = 0_u64; + let limit = limit.map_or(10, u64::from); + let mut num_since = 0_u64; - if let Some(s) = &since { - let mut characters = 
s.chars(); - let backwards = match characters.next() { - Some('n') => false, - Some('p') => true, - _ => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid `since` token", - )) - } - }; + if let Some(s) = &since { + let mut characters = s.chars(); + let backwards = match characters.next() { + Some('n') => false, + Some('p') => true, + _ => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token")), + }; - num_since = characters - .collect::() - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?; + num_since = characters + .collect::() + .parse() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token."))?; - if backwards { - num_since = num_since.saturating_sub(limit); - } - } + if backwards { + num_since = num_since.saturating_sub(limit); + } + } - let mut all_rooms: Vec<_> = services() - .rooms - .directory - .public_rooms() - .map(|room_id| { - let room_id = room_id?; + let mut all_rooms: Vec<_> = services() + .rooms + .directory + .public_rooms() + .map(|room_id| { + let room_id = room_id?; - let chunk = PublicRoomsChunk { - canonical_alias: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| { - Error::bad_database("Invalid canonical alias event in database.") - }) - })?, - name: services().rooms.state_accessor.get_name(&room_id)?, - num_joined_members: services() - .rooms - .state_cache - .room_joined_count(&room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - topic: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomTopic, "")? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|_| { - error!("Invalid room topic event in database for room {}", room_id); - Error::bad_database("Invalid room topic event in database.") - }) - }) - .unwrap_or(None), - world_readable: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|_| { - Error::bad_database( - "Invalid room history visibility event in database.", - ) - }) - })?, - guest_can_join: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| { - c.guest_access == GuestAccess::CanJoin - }) - .map_err(|_| { - Error::bad_database("Invalid room guest access event in database.") - }) - })?, - avatar_url: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomAvatarEventContent| c.url) - .map_err(|_| { - Error::bad_database("Invalid room avatar event in database.") - }) - }) - .transpose()? - // url is now an Option so we must flatten - .flatten(), - join_rule: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| match c.join_rule { - JoinRule::Public => Some(PublicRoomJoinRule::Public), - JoinRule::Knock => Some(PublicRoomJoinRule::Knock), - _ => None, - }) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) - }) - .transpose()? 
- .flatten() - .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, - room_type: services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::(s.content.get()).map_err( - |e| { - error!("Invalid room create event in database: {}", e); - Error::BadDatabase("Invalid room create event in database.") - }, - ) - }) - .transpose()? - .and_then(|e| e.room_type), - room_id, - }; - Ok(chunk) - }) - .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms - .filter(|chunk| { - if let Some(query) = filter - .generic_search_term - .as_ref() - .map(|q| q.to_lowercase()) - { - if let Some(name) = &chunk.name { - if name.as_str().to_lowercase().contains(&query) { - return true; - } - } + let chunk = PublicRoomsChunk { + canonical_alias: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCanonicalAlias, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomCanonicalAliasEventContent| c.alias) + .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + })?, + name: services().rooms.state_accessor.get_name(&room_id)?, + num_joined_members: services() + .rooms + .state_cache + .room_joined_count(&room_id)? + .unwrap_or_else(|| { + warn!("Room {} has no member count", room_id); + 0 + }) + .try_into() + .expect("user count should not be that big"), + topic: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomTopic, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomTopicEventContent| Some(c.topic)) + .map_err(|_| { + error!("Invalid room topic event in database for room {}", room_id); + Error::bad_database("Invalid room topic event in database.") + }) + }) + .unwrap_or(None), + world_readable: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomHistoryVisibility, "")? 
+ .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility == HistoryVisibility::WorldReadable + }) + .map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) + })?, + guest_can_join: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")? + .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) + })?, + avatar_url: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomAvatar, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomAvatarEventContent| c.url) + .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + }) + .transpose()? + // url is now an Option so we must flatten + .flatten(), + join_rule: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| match c.join_rule { + JoinRule::Public => Some(PublicRoomJoinRule::Public), + JoinRule::Knock => Some(PublicRoomJoinRule::Knock), + _ => None, + }) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? + .flatten() + .ok_or_else(|| Error::bad_database("Missing room join rule event for room."))?, + room_type: services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomCreate, "")? + .map(|s| { + serde_json::from_str::(s.content.get()).map_err(|e| { + error!("Invalid room create event in database: {}", e); + Error::BadDatabase("Invalid room create event in database.") + }) + }) + .transpose()? 
+ .and_then(|e| e.room_type), + room_id, + }; + Ok(chunk) + }) + .filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms + .filter(|chunk| { + if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) { + if let Some(name) = &chunk.name { + if name.as_str().to_lowercase().contains(&query) { + return true; + } + } - if let Some(topic) = &chunk.topic { - if topic.to_lowercase().contains(&query) { - return true; - } - } + if let Some(topic) = &chunk.topic { + if topic.to_lowercase().contains(&query) { + return true; + } + } - if let Some(canonical_alias) = &chunk.canonical_alias { - if canonical_alias.as_str().to_lowercase().contains(&query) { - return true; - } - } + if let Some(canonical_alias) = &chunk.canonical_alias { + if canonical_alias.as_str().to_lowercase().contains(&query) { + return true; + } + } - false - } else { - // No search term - true - } - }) - // We need to collect all, so we can sort by member count - .collect(); + false + } else { + // No search term + true + } + }) + // We need to collect all, so we can sort by member count + .collect(); - all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); + all_rooms.sort_by(|l, r| r.num_joined_members.cmp(&l.num_joined_members)); - let total_room_count_estimate = (all_rooms.len() as u32).into(); + let total_room_count_estimate = (all_rooms.len() as u32).into(); - let chunk: Vec<_> = all_rooms - .into_iter() - .skip(num_since as usize) - .take(limit as usize) - .collect(); + let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect(); - let prev_batch = if num_since == 0 { - None - } else { - Some(format!("p{num_since}")) - }; + let prev_batch = if num_since == 0 { + None + } else { + Some(format!("p{num_since}")) + }; - let next_batch = if chunk.len() < limit as usize { - None - } else { - Some(format!("n{}", num_since + limit)) - }; + let next_batch = if chunk.len() < limit as usize { + None + } else { + 
Some(format!("n{}", num_since + limit)) + }; - Ok(get_public_rooms_filtered::v3::Response { - chunk, - prev_batch, - next_batch, - total_room_count_estimate: Some(total_room_count_estimate), - }) + Ok(get_public_rooms_filtered::v3::Response { + chunk, + prev_batch, + next_batch, + total_room_count_estimate: Some(total_room_count_estimate), + }) } diff --git a/src/api/client_server/filter.rs b/src/api/client_server/filter.rs index e9a359d6..9e69f7c5 100644 --- a/src/api/client_server/filter.rs +++ b/src/api/client_server/filter.rs @@ -1,34 +1,31 @@ -use crate::{services, Error, Result, Ruma}; use ruma::api::client::{ - error::ErrorKind, - filter::{create_filter, get_filter}, + error::ErrorKind, + filter::{create_filter, get_filter}, }; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}` /// /// Loads a filter that was previously created. /// /// - A user can only access their own filters -pub async fn get_filter_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let filter = match services().users.get_filter(sender_user, &body.filter_id)? { - Some(filter) => filter, - None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), - }; +pub async fn get_filter_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let filter = match services().users.get_filter(sender_user, &body.filter_id)? { + Some(filter) => filter, + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Filter not found.")), + }; - Ok(get_filter::v3::Response::new(filter)) + Ok(get_filter::v3::Response::new(filter)) } /// # `PUT /_matrix/client/r0/user/{userId}/filter` /// /// Creates a new filter to be used by other endpoints. 
-pub async fn create_filter_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(create_filter::v3::Response::new( - services().users.create_filter(sender_user, &body.filter)?, - )) +pub async fn create_filter_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + Ok(create_filter::v3::Response::new( + services().users.create_filter(sender_user, &body.filter)?, + )) } diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index b57e28a8..e32d7a97 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -1,65 +1,53 @@ -use super::SESSION_ID_LENGTH; -use crate::{services, utils, Error, Result, Ruma}; +use std::{ + collections::{hash_map, BTreeMap, HashMap, HashSet}, + time::{Duration, Instant}, +}; + use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ - api::{ - client::{ - error::ErrorKind, - keys::{ - claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, - upload_signing_keys, - }, - uiaa::{AuthFlow, AuthType, UiaaInfo}, - }, - federation, - }, - serde::Raw, - DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, + api::{ + client::{ + error::ErrorKind, + keys::{claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, upload_signing_keys}, + uiaa::{AuthFlow, AuthType, UiaaInfo}, + }, + federation, + }, + serde::Raw, + DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId, }; use serde_json::json; -use std::{ - collections::{hash_map, BTreeMap, HashMap, HashSet}, - time::{Duration, Instant}, -}; use tracing::{debug, error}; +use super::SESSION_ID_LENGTH; +use crate::{services, utils, Error, Result, Ruma}; + /// # `POST /_matrix/client/r0/keys/upload` /// /// Publish end-to-end encryption keys for the sender device. /// /// - Adds one time keys -/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?) 
-pub async fn upload_keys_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); +/// - If there are no device keys yet: Adds device keys (TODO: merge with +/// existing keys?) +pub async fn upload_keys_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - for (key_key, key_value) in &body.one_time_keys { - services() - .users - .add_one_time_key(sender_user, sender_device, key_key, key_value)?; - } + for (key_key, key_value) in &body.one_time_keys { + services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?; + } - if let Some(device_keys) = &body.device_keys { - // TODO: merge this and the existing event? - // This check is needed to assure that signatures are kept - if services() - .users - .get_device_keys(sender_user, sender_device)? - .is_none() - { - services() - .users - .add_device_keys(sender_user, sender_device, device_keys)?; - } - } + if let Some(device_keys) = &body.device_keys { + // TODO: merge this and the existing event? + // This check is needed to assure that signatures are kept + if services().users.get_device_keys(sender_user, sender_device)?.is_none() { + services().users.add_device_keys(sender_user, sender_device, device_keys)?; + } + } - Ok(upload_keys::v3::Response { - one_time_key_counts: services() - .users - .count_one_time_keys(sender_user, sender_device)?, - }) + Ok(upload_keys::v3::Response { + one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?, + }) } /// # `POST /_matrix/client/r0/keys/query` @@ -68,30 +56,29 @@ pub async fn upload_keys_route( /// /// - Always fetches users from other servers over federation /// - Gets master keys, self-signing keys, user signing keys and device keys. 
-/// - The master and self-signing keys contain signatures that the user is allowed to see +/// - The master and self-signing keys contain signatures that the user is +/// allowed to see pub async fn get_keys_route(body: Ruma) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let response = get_keys_helper( - Some(sender_user), - &body.device_keys, - |u| u == sender_user, - true, // Always allow local users to see device names of other local users - ) - .await?; + let response = get_keys_helper( + Some(sender_user), + &body.device_keys, + |u| u == sender_user, + true, // Always allow local users to see device names of other local users + ) + .await?; - Ok(response) + Ok(response) } /// # `POST /_matrix/client/r0/keys/claim` /// /// Claims one-time keys -pub async fn claim_keys_route( - body: Ruma, -) -> Result { - let response = claim_keys_helper(&body.one_time_keys).await?; +pub async fn claim_keys_route(body: Ruma) -> Result { + let response = claim_keys_helper(&body.one_time_keys).await?; - Ok(response) + Ok(response) } /// # `POST /_matrix/client/r0/keys/device_signing/upload` @@ -100,452 +87,373 @@ pub async fn claim_keys_route( /// /// - Requires UIAA to verify password pub async fn upload_signing_keys_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - // UIAA - let mut uiaainfo = UiaaInfo { - flows: vec![AuthFlow { - stages: vec![AuthType::Password], - }], - completed: Vec::new(), - params: Box::default(), - session: None, - auth_error: None, - }; + // UIAA + let mut uiaainfo = UiaaInfo { + flows: vec![AuthFlow { + 
stages: vec![AuthType::Password], + }], + completed: Vec::new(), + params: Box::default(), + session: None, + auth_error: None, + }; - if let Some(auth) = &body.auth { - let (worked, uiaainfo) = - services() - .uiaa - .try_auth(sender_user, sender_device, auth, &uiaainfo)?; - if !worked { - return Err(Error::Uiaa(uiaainfo)); - } - // Success! - } else if let Some(json) = body.json_body { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - services() - .uiaa - .create(sender_user, sender_device, &uiaainfo, &json)?; - return Err(Error::Uiaa(uiaainfo)); - } else { - return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); - } + if let Some(auth) = &body.auth { + let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?; + if !worked { + return Err(Error::Uiaa(uiaainfo)); + } + // Success! + } else if let Some(json) = body.json_body { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?; + return Err(Error::Uiaa(uiaainfo)); + } else { + return Err(Error::BadRequest(ErrorKind::NotJson, "Not json.")); + } - if let Some(master_key) = &body.master_key { - services().users.add_cross_signing_keys( - sender_user, - master_key, - &body.self_signing_key, - &body.user_signing_key, - true, // notify so that other users see the new keys - )?; - } + if let Some(master_key) = &body.master_key { + services().users.add_cross_signing_keys( + sender_user, + master_key, + &body.self_signing_key, + &body.user_signing_key, + true, // notify so that other users see the new keys + )?; + } - Ok(upload_signing_keys::v3::Response {}) + Ok(upload_signing_keys::v3::Response {}) } /// # `POST /_matrix/client/r0/keys/signatures/upload` /// /// Uploads end-to-end key signatures from the sender user. 
pub async fn upload_signatures_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for (user_id, keys) in &body.signed_keys { - for (key_id, key) in keys { - let key = serde_json::to_value(key) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; + for (user_id, keys) in &body.signed_keys { + for (key_id, key) in keys { + let key = serde_json::to_value(key) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid key JSON"))?; - for signature in key - .get("signatures") - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Missing signatures field.", - ))? - .get(sender_user.to_string()) - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid user in signatures field.", - ))? - .as_object() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid signature.", - ))? - .clone() - .into_iter() - { - // Signature validation? - let signature = ( - signature.0, - signature - .1 - .as_str() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid signature value.", - ))? - .to_owned(), - ); - services() - .users - .sign_key(user_id, key_id, signature, sender_user)?; - } - } - } + for signature in key + .get("signatures") + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Missing signatures field."))? + .get(sender_user.to_string()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid user in signatures field."))? + .as_object() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature."))? + .clone() + .into_iter() + { + // Signature validation? + let signature = ( + signature.0, + signature + .1 + .as_str() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))? 
+ .to_owned(), + ); + services().users.sign_key(user_id, key_id, signature, sender_user)?; + } + } + } - Ok(upload_signatures::v3::Response { - failures: BTreeMap::new(), // TODO: integrate - }) + Ok(upload_signatures::v3::Response { + failures: BTreeMap::new(), // TODO: integrate + }) } /// # `POST /_matrix/client/r0/keys/changes` /// -/// Gets a list of users who have updated their device identity keys since the previous sync token. +/// Gets a list of users who have updated their device identity keys since the +/// previous sync token. /// /// - TODO: left users -pub async fn get_key_changes_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_key_changes_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut device_list_updates = HashSet::new(); + let mut device_list_updates = HashSet::new(); - device_list_updates.extend( - services() - .users - .keys_changed( - sender_user.as_str(), - body.from - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, - Some( - body.to - .parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?, - ), - ) - .filter_map(std::result::Result::ok), - ); + device_list_updates.extend( + services() + .users + .keys_changed( + sender_user.as_str(), + body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, + Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?), + ) + .filter_map(std::result::Result::ok), + ); - for room_id in services() - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(std::result::Result::ok) - { - device_list_updates.extend( - services() - .users - .keys_changed( - room_id.as_ref(), - body.from.parse().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.") - })?, - Some(body.to.parse().map_err(|_| 
{ - Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.") - })?), - ) - .filter_map(std::result::Result::ok), - ); - } - Ok(get_key_changes::v3::Response { - changed: device_list_updates.into_iter().collect(), - left: Vec::new(), // TODO - }) + for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok) { + device_list_updates.extend( + services() + .users + .keys_changed( + room_id.as_ref(), + body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?, + Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?), + ) + .filter_map(std::result::Result::ok), + ); + } + Ok(get_key_changes::v3::Response { + changed: device_list_updates.into_iter().collect(), + left: Vec::new(), // TODO + }) } pub(crate) async fn get_keys_helper bool>( - sender_user: Option<&UserId>, - device_keys_input: &BTreeMap>, - allowed_signatures: F, - include_display_names: bool, + sender_user: Option<&UserId>, device_keys_input: &BTreeMap>, allowed_signatures: F, + include_display_names: bool, ) -> Result { - let mut master_keys = BTreeMap::new(); - let mut self_signing_keys = BTreeMap::new(); - let mut user_signing_keys = BTreeMap::new(); - let mut device_keys = BTreeMap::new(); + let mut master_keys = BTreeMap::new(); + let mut self_signing_keys = BTreeMap::new(); + let mut user_signing_keys = BTreeMap::new(); + let mut device_keys = BTreeMap::new(); - let mut get_over_federation = HashMap::new(); + let mut get_over_federation = HashMap::new(); - for (user_id, device_ids) in device_keys_input { - let user_id: &UserId = user_id; + for (user_id, device_ids) in device_keys_input { + let user_id: &UserId = user_id; - if user_id.server_name() != services().globals.server_name() { - get_over_federation - .entry(user_id.server_name()) - .or_insert_with(Vec::new) - .push((user_id, device_ids)); - continue; - } + if user_id.server_name() != services().globals.server_name() { + 
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids)); + continue; + } - if device_ids.is_empty() { - let mut container = BTreeMap::new(); - for device_id in services().users.all_device_ids(user_id) { - let device_id = device_id?; - if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { - let metadata = services() - .users - .get_device_metadata(user_id, &device_id)? - .ok_or_else(|| { - Error::bad_database("all_device_keys contained nonexistent device.") - })?; + if device_ids.is_empty() { + let mut container = BTreeMap::new(); + for device_id in services().users.all_device_ids(user_id) { + let device_id = device_id?; + if let Some(mut keys) = services().users.get_device_keys(user_id, &device_id)? { + let metadata = services() + .users + .get_device_metadata(user_id, &device_id)? + .ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?; - add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; + add_unsigned_device_display_name(&mut keys, metadata, include_display_names) + .map_err(|_| Error::bad_database("invalid device keys in database"))?; - container.insert(device_id, keys); - } - } - device_keys.insert(user_id.to_owned(), container); - } else { - for device_id in device_ids { - let mut container = BTreeMap::new(); - if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? { - let metadata = services() - .users - .get_device_metadata(user_id, device_id)? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to get keys for nonexistent device.", - ))?; + container.insert(device_id, keys); + } + } + device_keys.insert(user_id.to_owned(), container); + } else { + for device_id in device_ids { + let mut container = BTreeMap::new(); + if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? 
{ + let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or( + Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."), + )?; - add_unsigned_device_display_name(&mut keys, metadata, include_display_names) - .map_err(|_| Error::bad_database("invalid device keys in database"))?; - container.insert(device_id.to_owned(), keys); - } - device_keys.insert(user_id.to_owned(), container); - } - } + add_unsigned_device_display_name(&mut keys, metadata, include_display_names) + .map_err(|_| Error::bad_database("invalid device keys in database"))?; + container.insert(device_id.to_owned(), keys); + } + device_keys.insert(user_id.to_owned(), container); + } + } - if let Some(master_key) = - services() - .users - .get_master_key(sender_user, user_id, &allowed_signatures)? - { - master_keys.insert(user_id.to_owned(), master_key); - } - if let Some(self_signing_key) = - services() - .users - .get_self_signing_key(sender_user, user_id, &allowed_signatures)? - { - self_signing_keys.insert(user_id.to_owned(), self_signing_key); - } - if Some(user_id) == sender_user { - if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? { - user_signing_keys.insert(user_id.to_owned(), user_signing_key); - } - } - } + if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? { + master_keys.insert(user_id.to_owned(), master_key); + } + if let Some(self_signing_key) = + services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)? + { + self_signing_keys.insert(user_id.to_owned(), self_signing_key); + } + if Some(user_id) == sender_user { + if let Some(user_signing_key) = services().users.get_user_signing_key(user_id)? 
{ + user_signing_keys.insert(user_id.to_owned(), user_signing_key); + } + } + } - let mut failures = BTreeMap::new(); + let mut failures = BTreeMap::new(); - let back_off = |id| match services() - .globals - .bad_query_ratelimiter - .write() - .unwrap() - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; + let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; - let mut futures: FuturesUnordered<_> = get_over_federation - .into_iter() - .map(|(server, vec)| async move { - if let Some((time, tries)) = services() - .globals - .bad_query_ratelimiter - .read() - .unwrap() - .get(server) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } + let mut futures: FuturesUnordered<_> = get_over_federation + .into_iter() + .map(|(server, vec)| async move { + if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } - if time.elapsed() < min_elapsed_duration { - debug!("Backing off query from {:?}", server); - return ( - server, - Err(Error::BadServerResponse("bad query, still backing off")), - ); - } - } + if time.elapsed() < min_elapsed_duration { + debug!("Backing off query from {:?}", server); + return (server, Err(Error::BadServerResponse("bad query, still backing off"))); + } + } - let mut 
device_keys_input_fed = BTreeMap::new(); - for (user_id, keys) in vec { - device_keys_input_fed.insert(user_id.to_owned(), keys.clone()); - } - ( - server, - tokio::time::timeout( - Duration::from_secs(50), - services().sending.send_federation_request( - server, - federation::keys::get_keys::v1::Request { - device_keys: device_keys_input_fed, - }, - ), - ) - .await - .map_err(|e| { - error!("get_keys_helper query took too long: {}", e); - Error::BadServerResponse("get_keys_helper query took too long") - }), - ) - }) - .collect(); + let mut device_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in vec { + device_keys_input_fed.insert(user_id.to_owned(), keys.clone()); + } + ( + server, + tokio::time::timeout( + Duration::from_secs(50), + services().sending.send_federation_request( + server, + federation::keys::get_keys::v1::Request { + device_keys: device_keys_input_fed, + }, + ), + ) + .await + .map_err(|e| { + error!("get_keys_helper query took too long: {}", e); + Error::BadServerResponse("get_keys_helper query took too long") + }), + ) + }) + .collect(); - while let Some((server, response)) = futures.next().await { - match response { - Ok(Ok(response)) => { - for (user, masterkey) in response.master_keys { - let (master_key_id, mut master_key) = - services().users.parse_master_key(&user, &masterkey)?; + while let Some((server, response)) = futures.next().await { + match response { + Ok(Ok(response)) => { + for (user, masterkey) in response.master_keys { + let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?; - if let Some(our_master_key) = services().users.get_key( - &master_key_id, - sender_user, - &user, - &allowed_signatures, - )? 
{ - let (_, our_master_key) = - services().users.parse_master_key(&user, &our_master_key)?; - master_key.signatures.extend(our_master_key.signatures); - } - let json = serde_json::to_value(master_key).expect("to_value always works"); - let raw = serde_json::from_value(json).expect("Raw::from_value always works"); - services().users.add_cross_signing_keys( - &user, &raw, &None, &None, - false, // Dont notify. A notification would trigger another key request resulting in an endless loop - )?; - master_keys.insert(user, raw); - } + if let Some(our_master_key) = + services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)? + { + let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?; + master_key.signatures.extend(our_master_key.signatures); + } + let json = serde_json::to_value(master_key).expect("to_value always works"); + let raw = serde_json::from_value(json).expect("Raw::from_value always works"); + services().users.add_cross_signing_keys( + &user, &raw, &None, &None, + false, /* Dont notify. 
A notification would trigger another key request resulting in an + * endless loop */ + )?; + master_keys.insert(user, raw); + } - self_signing_keys.extend(response.self_signing_keys); - device_keys.extend(response.device_keys); - } - _ => { - back_off(server.to_owned()); - failures.insert(server.to_string(), json!({})); - } - } - } + self_signing_keys.extend(response.self_signing_keys); + device_keys.extend(response.device_keys); + }, + _ => { + back_off(server.to_owned()); + failures.insert(server.to_string(), json!({})); + }, + } + } - Ok(get_keys::v3::Response { - master_keys, - self_signing_keys, - user_signing_keys, - device_keys, - failures, - }) + Ok(get_keys::v3::Response { + master_keys, + self_signing_keys, + user_signing_keys, + device_keys, + failures, + }) } fn add_unsigned_device_display_name( - keys: &mut Raw, - metadata: ruma::api::client::device::Device, - include_display_names: bool, + keys: &mut Raw, metadata: ruma::api::client::device::Device, + include_display_names: bool, ) -> serde_json::Result<()> { - if let Some(display_name) = metadata.display_name { - let mut object = keys.deserialize_as::>()?; + if let Some(display_name) = metadata.display_name { + let mut object = keys.deserialize_as::>()?; - let unsigned = object.entry("unsigned").or_insert_with(|| json!({})); - if let serde_json::Value::Object(unsigned_object) = unsigned { - if include_display_names { - unsigned_object.insert("device_display_name".to_owned(), display_name.into()); - } else { - unsigned_object.insert( - "device_display_name".to_owned(), - Some(metadata.device_id.as_str().to_owned()).into(), - ); - } - } + let unsigned = object.entry("unsigned").or_insert_with(|| json!({})); + if let serde_json::Value::Object(unsigned_object) = unsigned { + if include_display_names { + unsigned_object.insert("device_display_name".to_owned(), display_name.into()); + } else { + unsigned_object.insert( + "device_display_name".to_owned(), + 
Some(metadata.device_id.as_str().to_owned()).into(), + ); + } + } - *keys = Raw::from_json(serde_json::value::to_raw_value(&object)?); - } + *keys = Raw::from_json(serde_json::value::to_raw_value(&object)?); + } - Ok(()) + Ok(()) } pub(crate) async fn claim_keys_helper( - one_time_keys_input: &BTreeMap>, + one_time_keys_input: &BTreeMap>, ) -> Result { - let mut one_time_keys = BTreeMap::new(); + let mut one_time_keys = BTreeMap::new(); - let mut get_over_federation = BTreeMap::new(); + let mut get_over_federation = BTreeMap::new(); - for (user_id, map) in one_time_keys_input { - if user_id.server_name() != services().globals.server_name() { - get_over_federation - .entry(user_id.server_name()) - .or_insert_with(Vec::new) - .push((user_id, map)); - } + for (user_id, map) in one_time_keys_input { + if user_id.server_name() != services().globals.server_name() { + get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map)); + } - let mut container = BTreeMap::new(); - for (device_id, key_algorithm) in map { - if let Some(one_time_keys) = - services() - .users - .take_one_time_key(user_id, device_id, key_algorithm)? - { - let mut c = BTreeMap::new(); - c.insert(one_time_keys.0, one_time_keys.1); - container.insert(device_id.clone(), c); - } - } - one_time_keys.insert(user_id.clone(), container); - } + let mut container = BTreeMap::new(); + for (device_id, key_algorithm) in map { + if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? 
{ + let mut c = BTreeMap::new(); + c.insert(one_time_keys.0, one_time_keys.1); + container.insert(device_id.clone(), c); + } + } + one_time_keys.insert(user_id.clone(), container); + } - let mut failures = BTreeMap::new(); + let mut failures = BTreeMap::new(); - let mut futures: FuturesUnordered<_> = get_over_federation - .into_iter() - .map(|(server, vec)| async move { - let mut one_time_keys_input_fed = BTreeMap::new(); - for (user_id, keys) in vec { - one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); - } - ( - server, - services() - .sending - .send_federation_request( - server, - federation::keys::claim_keys::v1::Request { - one_time_keys: one_time_keys_input_fed, - }, - ) - .await, - ) - }) - .collect(); + let mut futures: FuturesUnordered<_> = get_over_federation + .into_iter() + .map(|(server, vec)| async move { + let mut one_time_keys_input_fed = BTreeMap::new(); + for (user_id, keys) in vec { + one_time_keys_input_fed.insert(user_id.clone(), keys.clone()); + } + ( + server, + services() + .sending + .send_federation_request( + server, + federation::keys::claim_keys::v1::Request { + one_time_keys: one_time_keys_input_fed, + }, + ) + .await, + ) + }) + .collect(); - while let Some((server, response)) = futures.next().await { - match response { - Ok(keys) => { - one_time_keys.extend(keys.one_time_keys); - } - Err(_e) => { - failures.insert(server.to_string(), json!({})); - } - } - } + while let Some((server, response)) = futures.next().await { + match response { + Ok(keys) => { + one_time_keys.extend(keys.one_time_keys); + }, + Err(_e) => { + failures.insert(server.to_string(), json!({})); + }, + } + } - Ok(claim_keys::v3::Response { - failures, - one_time_keys, - }) + Ok(claim_keys::v3::Response { + failures, + one_time_keys, + }) } diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index 20fabbcf..7d0cf7ab 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -1,22 +1,22 @@ use 
std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration}; -use crate::{ - service::media::{FileMeta, UrlPreviewData}, - services, utils, Error, Result, Ruma, -}; use image::io::Reader as ImgReader; - use reqwest::Url; use ruma::api::client::{ - error::ErrorKind, - media::{ - create_content, get_content, get_content_as_filename, get_content_thumbnail, - get_media_config, get_media_preview, - }, + error::ErrorKind, + media::{ + create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config, + get_media_preview, + }, }; use tracing::{debug, error, info, warn}; use webpage::HTML; +use crate::{ + service::media::{FileMeta, UrlPreviewData}, + services, utils, Error, Result, Ruma, +}; + /// generated MXC ID (`media-id`) length const MXC_LENGTH: usize = 32; @@ -24,48 +24,39 @@ const MXC_LENGTH: usize = 32; /// /// Returns max upload size. pub async fn get_media_config_route( - _body: Ruma, + _body: Ruma, ) -> Result { - Ok(get_media_config::v3::Response { - upload_size: services().globals.max_request_size().into(), - }) + Ok(get_media_config::v3::Response { + upload_size: services().globals.max_request_size().into(), + }) } /// # `GET /_matrix/media/v3/preview_url` /// /// Returns URL preview. 
pub async fn get_media_preview_route( - body: Ruma, + body: Ruma, ) -> Result { - let url = &body.url; - if !url_preview_allowed(url) { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "URL is not allowed to be previewed", - )); - } + let url = &body.url; + if !url_preview_allowed(url) { + return Err(Error::BadRequest(ErrorKind::Forbidden, "URL is not allowed to be previewed")); + } - if let Ok(preview) = get_url_preview(url).await { - let res = serde_json::value::to_raw_value(&preview).map_err(|e| { - error!( - "Failed to convert UrlPreviewData into a serde json value: {}", - e - ); - Error::BadRequest( - ErrorKind::Unknown, - "Unknown error occurred parsing URL preview", - ) - })?; + if let Ok(preview) = get_url_preview(url).await { + let res = serde_json::value::to_raw_value(&preview).map_err(|e| { + error!("Failed to convert UrlPreviewData into a serde json value: {}", e); + Error::BadRequest(ErrorKind::Unknown, "Unknown error occurred parsing URL preview") + })?; - return Ok(get_media_preview::v3::Response::from_raw_value(res)); - } + return Ok(get_media_preview::v3::Response::from_raw_value(res)); + } - Err(Error::BadRequest( - ErrorKind::LimitExceeded { - retry_after_ms: Some(Duration::from_secs(5)), - }, - "Retry later", - )) + Err(Error::BadRequest( + ErrorKind::LimitExceeded { + retry_after_ms: Some(Duration::from_secs(5)), + }, + "Retry later", + )) } /// # `POST /_matrix/media/v3/upload` @@ -74,80 +65,70 @@ pub async fn get_media_preview_route( /// /// - Some metadata will be saved in the database /// - Media will be saved in the media/ directory -pub async fn create_content_route( - body: Ruma, -) -> Result { - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(MXC_LENGTH) - ); +pub async fn create_content_route(body: Ruma) -> Result { + let mxc = format!( + "mxc://{}/{}", + services().globals.server_name(), + utils::random_string(MXC_LENGTH) + ); - services() - .media - .create( - mxc.clone(), - 
body.filename - .as_ref() - .map(|filename| "inline; filename=".to_owned() + filename) - .as_deref(), - body.content_type.as_deref(), - &body.file, - ) - .await?; + services() + .media + .create( + mxc.clone(), + body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(), + body.content_type.as_deref(), + &body.file, + ) + .await?; - let content_uri = mxc.into(); + let content_uri = mxc.into(); - Ok(create_content::v3::Response { - content_uri, - blurhash: None, - }) + Ok(create_content::v3::Response { + content_uri, + blurhash: None, + }) } /// helper method to fetch remote media from other servers over federation pub async fn get_remote_content( - mxc: &str, - server_name: &ruma::ServerName, - media_id: String, - allow_redirect: bool, - timeout_ms: Duration, + mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration, ) -> Result { - // we'll lie to the client and say the blocked server's media was not found and log. - // the client has no way of telling anyways so this is a security bonus. - if services() - .globals - .prevent_media_downloads_from() - .contains(&server_name.to_owned()) - { - info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc); - return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); - } + // we'll lie to the client and say the blocked server's media was not found and + // log. the client has no way of telling anyways so this is a security bonus. + if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) { + info!( + "Received request for remote media `{}` but server is in our media server blocklist. 
Returning 404.", + mxc + ); + return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + } - let content_response = services() - .sending - .send_federation_request( - server_name, - get_content::v3::Request { - allow_remote: true, - server_name: server_name.to_owned(), - media_id, - timeout_ms, - allow_redirect, - }, - ) - .await?; + let content_response = services() + .sending + .send_federation_request( + server_name, + get_content::v3::Request { + allow_remote: true, + server_name: server_name.to_owned(), + media_id, + timeout_ms, + allow_redirect, + }, + ) + .await?; - services() - .media - .create( - mxc.to_owned(), - content_response.content_disposition.as_deref(), - content_response.content_type.as_deref(), - &content_response.file, - ) - .await?; + services() + .media + .create( + mxc.to_owned(), + content_response.content_disposition.as_deref(), + content_response.content_type.as_deref(), + &content_response.file, + ) + .await?; - Ok(content_response) + Ok(content_response) } /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}` @@ -156,37 +137,36 @@ pub async fn get_remote_content( /// /// - Only allows federation if `allow_remote` is true /// - Only redirects if `allow_redirect` is true -/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds -pub async fn get_content_route( - body: Ruma, -) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); +/// - Uses client-provided `timeout_ms` if available, else defaults to 20 +/// seconds +pub async fn get_content_route(body: Ruma) -> Result { + let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); - if let Some(FileMeta { - content_disposition, - content_type, - file, - }) = services().media.get(mxc.clone()).await? 
- { - Ok(get_content::v3::Response { - file, - content_type, - content_disposition, - cross_origin_resource_policy: Some("cross-origin".to_owned()), - }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { - let remote_content_response = get_remote_content( - &mxc, - &body.server_name, - body.media_id.clone(), - body.allow_redirect, - body.timeout_ms, - ) - .await?; - Ok(remote_content_response) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) - } + if let Some(FileMeta { + content_disposition, + content_type, + file, + }) = services().media.get(mxc.clone()).await? + { + Ok(get_content::v3::Response { + file, + content_type, + content_disposition, + cross_origin_resource_policy: Some("cross-origin".to_owned()), + }) + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + let remote_content_response = get_remote_content( + &mxc, + &body.server_name, + body.media_id.clone(), + body.allow_redirect, + body.timeout_ms, + ) + .await?; + Ok(remote_content_response) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } } /// # `GET /_matrix/media/v3/download/{serverName}/{mediaId}/{fileName}` @@ -195,41 +175,44 @@ pub async fn get_content_route( /// /// - Only allows federation if `allow_remote` is true /// - Only redirects if `allow_redirect` is true -/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds +/// - Uses client-provided `timeout_ms` if available, else defaults to 20 +/// seconds pub async fn get_content_as_filename_route( - body: Ruma, + body: Ruma, ) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); + let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); - if let Some(FileMeta { - content_type, file, .. - }) = services().media.get(mxc.clone()).await? 
- { - Ok(get_content_as_filename::v3::Response { - file, - content_type, - content_disposition: Some(format!("inline; filename={}", body.filename)), - cross_origin_resource_policy: Some("cross-origin".to_owned()), - }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { - let remote_content_response = get_remote_content( - &mxc, - &body.server_name, - body.media_id.clone(), - body.allow_redirect, - body.timeout_ms, - ) - .await?; + if let Some(FileMeta { + content_type, + file, + .. + }) = services().media.get(mxc.clone()).await? + { + Ok(get_content_as_filename::v3::Response { + file, + content_type, + content_disposition: Some(format!("inline; filename={}", body.filename)), + cross_origin_resource_policy: Some("cross-origin".to_owned()), + }) + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + let remote_content_response = get_remote_content( + &mxc, + &body.server_name, + body.media_id.clone(), + body.allow_redirect, + body.timeout_ms, + ) + .await?; - Ok(get_content_as_filename::v3::Response { - content_disposition: Some(format!("inline: filename={}", body.filename)), - content_type: remote_content_response.content_type, - file: remote_content_response.file, - cross_origin_resource_policy: Some("cross-origin".to_owned()), - }) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) - } + Ok(get_content_as_filename::v3::Response { + content_disposition: Some(format!("inline: filename={}", body.filename)), + content_type: remote_content_response.content_type, + file: remote_content_response.file, + cross_origin_resource_policy: Some("cross-origin".to_owned()), + }) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } } /// # `GET /_matrix/media/v3/thumbnail/{serverName}/{mediaId}` @@ -238,157 +221,152 @@ pub async fn get_content_as_filename_route( /// /// - Only allows federation if `allow_remote` is true /// - Only redirects if 
`allow_redirect` is true -/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds +/// - Uses client-provided `timeout_ms` if available, else defaults to 20 +/// seconds pub async fn get_content_thumbnail_route( - body: Ruma, + body: Ruma, ) -> Result { - let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); + let mxc = format!("mxc://{}/{}", body.server_name, body.media_id); - if let Some(FileMeta { - content_type, file, .. - }) = services() - .media - .get_thumbnail( - mxc.clone(), - body.width - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, - body.height - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, - ) - .await? - { - Ok(get_content_thumbnail::v3::Response { - file, - content_type, - cross_origin_resource_policy: Some("cross-origin".to_owned()), - }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { - // we'll lie to the client and say the blocked server's media was not found and log. - // the client has no way of telling anyways so this is a security bonus. - if services() - .globals - .prevent_media_downloads_from() - .contains(&body.server_name.clone()) - { - info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc); - return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); - } + if let Some(FileMeta { + content_type, + file, + .. + }) = services() + .media + .get_thumbnail( + mxc.clone(), + body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?, + body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?, + ) + .await? 
+ { + Ok(get_content_thumbnail::v3::Response { + file, + content_type, + cross_origin_resource_policy: Some("cross-origin".to_owned()), + }) + } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + // we'll lie to the client and say the blocked server's media was not found and + // log. the client has no way of telling anyways so this is a security bonus. + if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) { + info!( + "Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", + mxc + ); + return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")); + } - let get_thumbnail_response = services() - .sending - .send_federation_request( - &body.server_name, - get_content_thumbnail::v3::Request { - allow_remote: body.allow_remote, - height: body.height, - width: body.width, - method: body.method.clone(), - server_name: body.server_name.clone(), - media_id: body.media_id.clone(), - timeout_ms: body.timeout_ms, - allow_redirect: body.allow_redirect, - }, - ) - .await?; + let get_thumbnail_response = services() + .sending + .send_federation_request( + &body.server_name, + get_content_thumbnail::v3::Request { + allow_remote: body.allow_remote, + height: body.height, + width: body.width, + method: body.method.clone(), + server_name: body.server_name.clone(), + media_id: body.media_id.clone(), + timeout_ms: body.timeout_ms, + allow_redirect: body.allow_redirect, + }, + ) + .await?; - services() - .media - .upload_thumbnail( - mxc, - None, - get_thumbnail_response.content_type.as_deref(), - body.width.try_into().expect("all UInts are valid u32s"), - body.height.try_into().expect("all UInts are valid u32s"), - &get_thumbnail_response.file, - ) - .await?; + services() + .media + .upload_thumbnail( + mxc, + None, + get_thumbnail_response.content_type.as_deref(), + body.width.try_into().expect("all UInts are valid u32s"), + 
body.height.try_into().expect("all UInts are valid u32s"), + &get_thumbnail_response.file, + ) + .await?; - Ok(get_thumbnail_response) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) - } + Ok(get_thumbnail_response) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "Media not found.")) + } } async fn download_image(client: &reqwest::Client, url: &str) -> Result { - let image = client.get(url).send().await?.bytes().await?; - let mxc = format!( - "mxc://{}/{}", - services().globals.server_name(), - utils::random_string(MXC_LENGTH) - ); + let image = client.get(url).send().await?.bytes().await?; + let mxc = format!( + "mxc://{}/{}", + services().globals.server_name(), + utils::random_string(MXC_LENGTH) + ); - services() - .media - .create(mxc.clone(), None, None, &image) - .await?; + services().media.create(mxc.clone(), None, None, &image).await?; - let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() { - Err(_) => (None, None), - Ok(reader) => match reader.into_dimensions() { - Err(_) => (None, None), - Ok((width, height)) => (Some(width), Some(height)), - }, - }; + let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() { + Err(_) => (None, None), + Ok(reader) => match reader.into_dimensions() { + Err(_) => (None, None), + Ok((width, height)) => (Some(width), Some(height)), + }, + }; - Ok(UrlPreviewData { - image: Some(mxc), - image_size: Some(image.len()), - image_width: width, - image_height: height, - ..Default::default() - }) + Ok(UrlPreviewData { + image: Some(mxc), + image_size: Some(image.len()), + image_width: width, + image_height: height, + ..Default::default() + }) } async fn download_html(client: &reqwest::Client, url: &str) -> Result { - let mut response = client.get(url).send().await?; + let mut response = client.get(url).send().await?; - let mut bytes: Vec = Vec::new(); - while let Some(chunk) = response.chunk().await? 
{ - bytes.extend_from_slice(&chunk); - if bytes.len() > services().globals.url_preview_max_spider_size() { - debug!("Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the response body and assuming our necessary data is in this range.", url, services().globals.url_preview_max_spider_size()); - break; - } - } - let body = String::from_utf8_lossy(&bytes); - let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) { - Ok(html) => html, - Err(_) => { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Failed to parse HTML", - )) - } - }; + let mut bytes: Vec = Vec::new(); + while let Some(chunk) = response.chunk().await? { + bytes.extend_from_slice(&chunk); + if bytes.len() > services().globals.url_preview_max_spider_size() { + debug!( + "Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \ + response body and assuming our necessary data is in this range.", + url, + services().globals.url_preview_max_spider_size() + ); + break; + } + } + let body = String::from_utf8_lossy(&bytes); + let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) { + Ok(html) => html, + Err(_) => return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML")), + }; - let mut data = match html.opengraph.images.first() { - None => UrlPreviewData::default(), - Some(obj) => download_image(client, &obj.url).await?, - }; + let mut data = match html.opengraph.images.first() { + None => UrlPreviewData::default(), + Some(obj) => download_image(client, &obj.url).await?, + }; - let props = html.opengraph.properties; + let props = html.opengraph.properties; - /* use OpenGraph title/description, but fall back to HTML if not available */ - data.title = props.get("title").cloned().or(html.title); - data.description = props.get("description").cloned().or(html.description); + /* use OpenGraph title/description, but fall back to HTML if not available */ + data.title = 
props.get("title").cloned().or(html.title); + data.description = props.get("description").cloned().or(html.description); - Ok(data) + Ok(data) } fn url_request_allowed(addr: &IpAddr) -> bool { - // TODO: make this check ip_range_denylist + // TODO: make this check ip_range_denylist - // could be implemented with reqwest when it supports IP filtering: - // https://github.com/seanmonstar/reqwest/issues/1515 + // could be implemented with reqwest when it supports IP filtering: + // https://github.com/seanmonstar/reqwest/issues/1515 - // These checks have been taken from the Rust core/net/ipaddr.rs crate, - // IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not - // yet stabilized. TODO: Once this is stable, this match can be simplified. - match addr { - IpAddr::V4(ip4) => { - !(ip4.octets()[0] == 0 // "This network" + // These checks have been taken from the Rust core/net/ipaddr.rs crate, + // IpAddr::V4.is_global() and IpAddr::V6.is_global(), as .is_global is not + // yet stabilized. TODO: Once this is stable, this match can be simplified. 
+ match addr { + IpAddr::V4(ip4) => { + !(ip4.octets()[0] == 0 // "This network" || ip4.is_private() || (ip4.octets()[0] == 100 && (ip4.octets()[1] & 0b1100_0000 == 0b0100_0000)) // is_shared() || ip4.is_loopback() @@ -399,9 +377,9 @@ fn url_request_allowed(addr: &IpAddr) -> bool { || (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking() || (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved() || ip4.is_broadcast()) - } - IpAddr::V6(ip6) => { - !(ip6.is_unspecified() + }, + IpAddr::V6(ip6) => { + !(ip6.is_unspecified() || ip6.is_loopback() // IPv4-mapped Address (`::ffff:0:0/96`) || matches!(ip6.segments(), [0, 0, 0, 0, 0, 0xffff, _, _]) @@ -426,178 +404,127 @@ fn url_request_allowed(addr: &IpAddr) -> bool { || ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation() || ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local() || ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local - } - } + }, + } } async fn request_url_preview(url: &str) -> Result { - let client = services().globals.url_preview_client(); - let response = client.head(url).send().await?; + let client = services().globals.url_preview_client(); + let response = client.head(url).send().await?; - if !response - .remote_addr() - .map_or(false, |a| url_request_allowed(&a.ip())) - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Requesting from this address is forbidden", - )); - } + if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Requesting from this address is forbidden", + )); + } - let content_type = match response - .headers() - .get(reqwest::header::CONTENT_TYPE) - .and_then(|x| x.to_str().ok()) - { - Some(ct) => ct, - None => { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Unknown Content-Type", - )) - } - }; - let data = match content_type { - html if html.starts_with("text/html") => 
download_html(&client, url).await?, - img if img.starts_with("image/") => download_image(&client, url).await?, - _ => { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Unsupported Content-Type", - )) - } - }; + let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) { + Some(ct) => ct, + None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")), + }; + let data = match content_type { + html if html.starts_with("text/html") => download_html(&client, url).await?, + img if img.starts_with("image/") => download_image(&client, url).await?, + _ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")), + }; - services().media.set_url_preview(url, &data).await?; + services().media.set_url_preview(url, &data).await?; - Ok(data) + Ok(data) } async fn get_url_preview(url: &str) -> Result { - if let Some(preview) = services().media.get_url_preview(url).await { - return Ok(preview); - } + if let Some(preview) = services().media.get_url_preview(url).await { + return Ok(preview); + } - // ensure that only one request is made per URL - let mutex_request = Arc::clone( - services() - .media - .url_preview_mutex - .write() - .unwrap() - .entry(url.to_owned()) - .or_default(), - ); - let _request_lock = mutex_request.lock().await; + // ensure that only one request is made per URL + let mutex_request = + Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default()); + let _request_lock = mutex_request.lock().await; - match services().media.get_url_preview(url).await { - Some(preview) => Ok(preview), - None => request_url_preview(url).await, - } + match services().media.get_url_preview(url).await { + Some(preview) => Ok(preview), + None => request_url_preview(url).await, + } } fn url_preview_allowed(url_str: &str) -> bool { - let url: Url = match Url::parse(url_str) { - Ok(u) => u, - Err(e) => { - warn!("Failed to parse URL from a str: {}", 
e); - return false; - } - }; + let url: Url = match Url::parse(url_str) { + Ok(u) => u, + Err(e) => { + warn!("Failed to parse URL from a str: {}", e); + return false; + }, + }; - if ["http", "https"] - .iter() - .all(|&scheme| scheme != url.scheme().to_lowercase()) - { - debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url); - return false; - } + if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) { + debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url); + return false; + } - let host = match url.host_str() { - None => { - debug!( - "Ignoring URL preview for a URL that does not have a host (?): {}", - url - ); - return false; - } - Some(h) => h.to_owned(), - }; + let host = match url.host_str() { + None => { + debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url); + return false; + }, + Some(h) => h.to_owned(), + }; - let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist(); - let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist(); - let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist(); + let allowlist_domain_contains = services().globals.url_preview_domain_contains_allowlist(); + let allowlist_domain_explicit = services().globals.url_preview_domain_explicit_allowlist(); + let allowlist_url_contains = services().globals.url_preview_url_contains_allowlist(); - if allowlist_domain_contains.contains(&"*".to_owned()) - || allowlist_domain_explicit.contains(&"*".to_owned()) - || allowlist_url_contains.contains(&"*".to_owned()) - { - debug!( - "Config key contains * which is allowing all URL previews. Allowing URL {}", - url - ); - return true; - } + if allowlist_domain_contains.contains(&"*".to_owned()) + || allowlist_domain_explicit.contains(&"*".to_owned()) + || allowlist_url_contains.contains(&"*".to_owned()) + { + debug!("Config key contains * which is allowing all URL previews. 
Allowing URL {}", url); + return true; + } - if !host.is_empty() { - if allowlist_domain_explicit.contains(&host) { - debug!( - "Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", - &host - ); - return true; - } + if !host.is_empty() { + if allowlist_domain_explicit.contains(&host) { + debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", &host); + return true; + } - if allowlist_domain_contains - .iter() - .any(|domain_s| domain_s.contains(&host.clone())) - { - debug!( - "Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", - &host - ); - return true; - } + if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) { + debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host); + return true; + } - if allowlist_url_contains - .iter() - .any(|url_s| url.to_string().contains(&url_s.to_string())) - { - debug!( - "URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", - &host - ); - return true; - } + if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) { + debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host); + return true; + } - // check root domain if available and if user has root domain checks - if services().globals.url_preview_check_root_domain() { - debug!("Checking root domain"); - match host.split_once('.') { - None => return false, - Some((_, root_domain)) => { - if allowlist_domain_explicit.contains(&root_domain.to_owned()) { - debug!( - "Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", - &root_domain - ); - return true; - } + // check root domain if available and if user has root domain checks + if services().globals.url_preview_check_root_domain() { + debug!("Checking root domain"); + match host.split_once('.') { + None => return false, + Some((_, root_domain)) => { + if 
allowlist_domain_explicit.contains(&root_domain.to_owned()) { + debug!( + "Root domain {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", + &root_domain + ); + return true; + } - if allowlist_domain_contains - .iter() - .any(|domain_s| domain_s.contains(&root_domain.to_owned())) - { - debug!( - "Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", - &root_domain - ); - return true; - } - } - } - } - } + if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) { + debug!( + "Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", + &root_domain + ); + return true; + } + }, + } + } + } - false + false } diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 4e8ebaa5..e046df96 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -1,170 +1,169 @@ +use std::{ + collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, + sync::{Arc, RwLock}, + time::{Duration, Instant}, +}; + use ruma::{ - api::{ - client::{ - error::ErrorKind, - membership::{ - ban_user, forget_room, get_member_events, invite_user, join_room_by_id, - join_room_by_id_or_alias, joined_members, joined_rooms, kick_user, leave_room, - unban_user, ThirdPartySigned, - }, - }, - federation::{self, membership::create_invite}, - }, - canonical_json::to_canonical_value, - events::{ - room::{ - join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - power_levels::RoomPowerLevelsEventContent, - }, - StateEventType, TimelineEventType, - }, - serde::Base64, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, - OwnedServerName, OwnedUserId, RoomId, RoomVersionId, UserId, + api::{ + client::{ + error::ErrorKind, + membership::{ + ban_user, forget_room, get_member_events, invite_user, join_room_by_id, join_room_by_id_or_alias, + 
joined_members, joined_rooms, kick_user, leave_room, unban_user, ThirdPartySigned, + }, + }, + federation::{self, membership::create_invite}, + }, + canonical_json::to_canonical_value, + events::{ + room::{ + join_rules::{AllowRule, JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, + }, + StateEventType, TimelineEventType, + }, + serde::Base64, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, + OwnedUserId, RoomId, RoomVersionId, UserId, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::{hash_map::Entry, BTreeMap, HashMap, HashSet}, - sync::{Arc, RwLock}, - time::{Duration, Instant}, -}; use tracing::{debug, error, info, warn}; -use crate::{ - service::pdu::{gen_event_id_canonical_json, PduBuilder}, - services, utils, Error, PduEvent, Result, Ruma, -}; - use super::get_alias_helper; +use crate::{ + service::pdu::{gen_event_id_canonical_json, PduBuilder}, + services, utils, Error, PduEvent, Result, Ruma, +}; /// # `POST /_matrix/client/r0/rooms/{roomId}/join` /// /// Tries to join the sender user into a room. /// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation -pub async fn join_room_by_id_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// - If the server knowns about this room: creates the join event and does auth +/// rules locally +/// - If the server does not know about the room: asks other servers over +/// federation +pub async fn join_room_by_id_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if services().rooms.metadata.is_banned(&body.room_id)? - && !services().users.is_admin(sender_user)? 
- { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This room is banned on this homeserver.", - )); - } + if services().rooms.metadata.is_banned(&body.room_id)? && !services().users.is_admin(sender_user)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This room is banned on this homeserver.", + )); + } - let mut servers = Vec::new(); // There is no body.server_name for /roomId/join - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let mut servers = Vec::new(); // There is no body.server_name for /roomId/join + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); - // server names being permanently attached to room IDs may be potentally removed in the future (see MSC4051). - // for future compatibility with this, and just because it makes sense, we shouldn't fail if the room ID - // doesn't have a server name with it and just use at least the server name from the initial invite above - if let Some(server) = body.room_id.server_name() { - servers.push(server.into()); - } + // server names being permanently attached to room IDs may be potentally removed + // in the future (see MSC4051). 
for future compatibility with this, and just + // because it makes sense, we shouldn't fail if the room ID doesn't have a + // server name with it and just use at least the server name from the initial + // invite above + if let Some(server) = body.room_id.server_name() { + servers.push(server.into()); + } - join_room_by_id_helper( - body.sender_user.as_deref(), - &body.room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await + join_room_by_id_helper( + body.sender_user.as_deref(), + &body.room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await } /// # `POST /_matrix/client/r0/join/{roomIdOrAlias}` /// /// Tries to join the sender user into a room. /// -/// - If the server knowns about this room: creates the join event and does auth rules locally -/// - If the server does not know about the room: asks other servers over federation +/// - If the server knowns about this room: creates the join event and does auth +/// rules locally +/// - If the server does not know about the room: asks other servers over +/// federation pub async fn join_room_by_id_or_alias_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - let body = body.body; + let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + let body = body.body; - let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { - Ok(room_id) => { - if services().rooms.metadata.is_banned(&room_id)? - && !services().users.is_admin(sender_user)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This room is banned on this homeserver.", - )); - } + let (servers, room_id) = match OwnedRoomId::try_from(body.room_id_or_alias) { + Ok(room_id) => { + if services().rooms.metadata.is_banned(&room_id)? && !services().users.is_admin(sender_user)? 
{ + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This room is banned on this homeserver.", + )); + } - let mut servers = body.server_name.clone(); - servers.extend( - services() - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), - ); + let mut servers = body.server_name.clone(); + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); - // server names being permanently attached to room IDs may be potentally removed in the future (see MSC4051). - // for future compatibility with this, and just because it makes sense, we shouldn't fail if the room ID - // doesn't have a server name with it and just use at least the server name from the initial invite above - if let Some(server) = room_id.server_name() { - servers.push(server.into()); - } + // server names being permanently attached to room IDs may be potentally removed + // in the future (see MSC4051). 
for future compatibility with this, and just + // because it makes sense, we shouldn't fail if the room ID doesn't have a + // server name with it and just use at least the server name from the initial + // invite above + if let Some(server) = room_id.server_name() { + servers.push(server.into()); + } - (servers, room_id) - } - Err(room_alias) => { - let response = get_alias_helper(room_alias).await?; + (servers, room_id) + }, + Err(room_alias) => { + let response = get_alias_helper(room_alias).await?; - if services().rooms.metadata.is_banned(&response.room_id)? - && !services().users.is_admin(sender_user)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This room is banned on this homeserver.", - )); - } + if services().rooms.metadata.is_banned(&response.room_id)? && !services().users.is_admin(sender_user)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This room is banned on this homeserver.", + )); + } - (response.servers, response.room_id) - } - }; + (response.servers, response.room_id) + }, + }; - let join_room_response = join_room_by_id_helper( - Some(sender_user), - &room_id, - body.reason.clone(), - &servers, - body.third_party_signed.as_ref(), - ) - .await?; + let join_room_response = join_room_by_id_helper( + Some(sender_user), + &room_id, + body.reason.clone(), + &servers, + body.third_party_signed.as_ref(), + ) + .await?; - Ok(join_room_by_id_or_alias::v3::Response { - room_id: join_room_response.room_id, - }) + Ok(join_room_by_id_or_alias::v3::Response { + room_id: join_room_response.room_id, + }) } /// # `POST /_matrix/client/v3/rooms/{roomId}/leave` @@ -172,336 +171,271 @@ pub async fn join_room_by_id_or_alias_route( /// Tries to leave the sender user from a room. /// /// - This should always work if the user is currently joined. 
-pub async fn leave_room_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn leave_room_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - leave_room(sender_user, &body.room_id, body.reason.clone()).await?; + leave_room(sender_user, &body.room_id, body.reason.clone()).await?; - Ok(leave_room::v3::Response::new()) + Ok(leave_room::v3::Response::new()) } /// # `POST /_matrix/client/r0/rooms/{roomId}/invite` /// /// Tries to send an invite event into the room. -pub async fn invite_user_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn invite_user_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().users.is_admin(sender_user)? && services().globals.block_non_admin_invites() { - info!( - "User {sender_user} is not an admin and attempted to send an invite to room {}", - &body.room_id - ); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Invites are not allowed on this server.", - )); - } + if !services().users.is_admin(sender_user)? && services().globals.block_non_admin_invites() { + info!( + "User {sender_user} is not an admin and attempted to send an invite to room {}", + &body.room_id + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Invites are not allowed on this server.", + )); + } - if services().rooms.metadata.is_banned(&body.room_id)? - && !services().users.is_admin(sender_user)? - { - info!( - "Local user {} who is not an admin attempted to send an invite for banned room {}.", - &sender_user, &body.room_id - ); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This room is banned on this homeserver.", - )); - } + if services().rooms.metadata.is_banned(&body.room_id)? && !services().users.is_admin(sender_user)? 
{ + info!( + "Local user {} who is not an admin attempted to send an invite for banned room {}.", + &sender_user, &body.room_id + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This room is banned on this homeserver.", + )); + } - if let invite_user::v3::InvitationRecipient::UserId { user_id } = &body.recipient { - invite_helper( - sender_user, - user_id, - &body.room_id, - body.reason.clone(), - false, - ) - .await?; - Ok(invite_user::v3::Response {}) - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) - } + if let invite_user::v3::InvitationRecipient::UserId { + user_id, + } = &body.recipient + { + invite_helper(sender_user, user_id, &body.room_id, body.reason.clone(), false).await?; + Ok(invite_user::v3::Response {}) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "User not found.")) + } } /// # `POST /_matrix/client/r0/rooms/{roomId}/kick` /// /// Tries to send a kick event into the room. -pub async fn kick_user_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn kick_user_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut event: RoomMemberEventContent = serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - body.user_id.as_ref(), - )? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot kick member that's not in the room.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let mut event: RoomMemberEventContent = serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? + .ok_or(Error::BadRequest( + ErrorKind::BadState, + "Cannot kick member that's not in the room.", + ))? 
+ .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; - event.membership = MembershipState::Leave; - event.reason = body.reason.clone(); + event.membership = MembershipState::Leave; + event.reason = body.reason.clone(); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); - Ok(kick_user::v3::Response::new()) + Ok(kick_user::v3::Response::new()) } /// # `POST /_matrix/client/r0/rooms/{roomId}/ban` /// /// Tries to send a ban event into the room. pub async fn ban_user_route(body: Ruma) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() - .rooms - .state_accessor - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - body.user_id.as_ref(), - )? 
- .map_or( - Ok(RoomMemberEventContent { - membership: MembershipState::Ban, - displayname: services().users.displayname(&body.user_id)?, - avatar_url: services().users.avatar_url(&body.user_id)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(&body.user_id)?, - reason: body.reason.clone(), - join_authorized_via_users_server: None, - }), - |event| { - serde_json::from_str(event.content.get()) - .map(|event: RoomMemberEventContent| RoomMemberEventContent { - membership: MembershipState::Ban, - displayname: services() - .users - .displayname(&body.user_id) - .unwrap_or_default(), - avatar_url: services() - .users - .avatar_url(&body.user_id) - .unwrap_or_default(), - blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(), - reason: body.reason.clone(), - ..event - }) - .map_err(|_| Error::bad_database("Invalid member event in database.")) - }, - )?; + let event = services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? 
+ .map_or( + Ok(RoomMemberEventContent { + membership: MembershipState::Ban, + displayname: services().users.displayname(&body.user_id)?, + avatar_url: services().users.avatar_url(&body.user_id)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(&body.user_id)?, + reason: body.reason.clone(), + join_authorized_via_users_server: None, + }), + |event| { + serde_json::from_str(event.content.get()) + .map(|event: RoomMemberEventContent| RoomMemberEventContent { + membership: MembershipState::Ban, + displayname: services().users.displayname(&body.user_id).unwrap_or_default(), + avatar_url: services().users.avatar_url(&body.user_id).unwrap_or_default(), + blurhash: services().users.blurhash(&body.user_id).unwrap_or_default(), + reason: body.reason.clone(), + ..event + }) + .map_err(|_| Error::bad_database("Invalid member event in database.")) + }, + )?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); 
- Ok(ban_user::v3::Response::new()) + Ok(ban_user::v3::Response::new()) } /// # `POST /_matrix/client/r0/rooms/{roomId}/unban` /// /// Tries to send an unban event into the room. -pub async fn unban_user_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn unban_user_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut event: RoomMemberEventContent = serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get( - &body.room_id, - &StateEventType::RoomMember, - body.user_id.as_ref(), - )? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "Cannot unban a user who is not banned.", - ))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid member event in database."))?; + let mut event: RoomMemberEventContent = serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomMember, body.user_id.as_ref())? + .ok_or(Error::BadRequest(ErrorKind::BadState, "Cannot unban a user who is not banned."))? 
+ .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; - event.membership = MembershipState::Leave; - event.reason = body.reason.clone(); + event.membership = MembershipState::Leave; + event.reason = body.reason.clone(); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); - Ok(unban_user::v3::Response::new()) + Ok(unban_user::v3::Response::new()) } /// # `POST /_matrix/client/v3/rooms/{roomId}/forget` /// /// Forgets about a room. 
/// -/// - If the sender user currently left the room: Stops sender user from receiving information about the room +/// - If the sender user currently left the room: Stops sender user from +/// receiving information about the room /// -/// Note: Other devices of the user have no way of knowing the room was forgotten, so this has to -/// be called from every device -pub async fn forget_room_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// Note: Other devices of the user have no way of knowing the room was +/// forgotten, so this has to be called from every device +pub async fn forget_room_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .rooms - .state_cache - .forget(&body.room_id, sender_user)?; + services().rooms.state_cache.forget(&body.room_id, sender_user)?; - Ok(forget_room::v3::Response::new()) + Ok(forget_room::v3::Response::new()) } /// # `POST /_matrix/client/r0/joined_rooms` /// /// Lists all rooms the user has joined. -pub async fn joined_rooms_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn joined_rooms_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(joined_rooms::v3::Response { - joined_rooms: services() - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(std::result::Result::ok) - .collect(), - }) + Ok(joined_rooms::v3::Response { + joined_rooms: services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(std::result::Result::ok) + .collect(), + }) } /// # `POST /_matrix/client/r0/rooms/{roomId}/members` /// -/// Lists all joined users in a room (TODO: at a specific point in time, with a specific membership). +/// Lists all joined users in a room (TODO: at a specific point in time, with a +/// specific membership). 
/// /// - Only works if the user is currently joined pub async fn get_member_events_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } + if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } - Ok(get_member_events::v3::Response { - chunk: services() - .rooms - .state_accessor - .room_state_full(&body.room_id) - .await? - .iter() - .filter(|(key, _)| key.0 == StateEventType::RoomMember) - .map(|(_, pdu)| pdu.to_member_event()) - .collect(), - }) + Ok(get_member_events::v3::Response { + chunk: services() + .rooms + .state_accessor + .room_state_full(&body.room_id) + .await? + .iter() + .filter(|(key, _)| key.0 == StateEventType::RoomMember) + .map(|(_, pdu)| pdu.to_member_event()) + .collect(), + }) } /// # `POST /_matrix/client/r0/rooms/{roomId}/joined_members` @@ -510,1219 +444,1030 @@ pub async fn get_member_events_route( /// /// - The sender user must be in the room /// - TODO: An appservice just needs a puppet joined -pub async fn joined_members_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn joined_members_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? 
- { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } + if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } - let mut joined = BTreeMap::new(); - for user_id in services() - .rooms - .state_cache - .room_members(&body.room_id) - .filter_map(std::result::Result::ok) - { - let display_name = services().users.displayname(&user_id)?; - let avatar_url = services().users.avatar_url(&user_id)?; + let mut joined = BTreeMap::new(); + for user_id in services().rooms.state_cache.room_members(&body.room_id).filter_map(std::result::Result::ok) { + let display_name = services().users.displayname(&user_id)?; + let avatar_url = services().users.avatar_url(&user_id)?; - joined.insert( - user_id, - joined_members::v3::RoomMember { - display_name, - avatar_url, - }, - ); - } + joined.insert( + user_id, + joined_members::v3::RoomMember { + display_name, + avatar_url, + }, + ); + } - Ok(joined_members::v3::Response { joined }) + Ok(joined_members::v3::Response { + joined, + }) } async fn join_room_by_id_helper( - sender_user: Option<&UserId>, - room_id: &RoomId, - reason: Option, - servers: &[OwnedServerName], - _third_party_signed: Option<&ThirdPartySigned>, + sender_user: Option<&UserId>, room_id: &RoomId, reason: Option, servers: &[OwnedServerName], + _third_party_signed: Option<&ThirdPartySigned>, ) -> Result { - let sender_user = sender_user.expect("user is authenticated"); + let sender_user = sender_user.expect("user is authenticated"); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + 
let state_lock = mutex_state.lock().await; - // Ask a remote server if we are not participating in this room - if !services() - .rooms - .state_cache - .server_in_room(services().globals.server_name(), room_id)? - { - info!("Joining {room_id} over federation."); + // Ask a remote server if we are not participating in this room + if !services().rooms.state_cache.server_in_room(services().globals.server_name(), room_id)? { + info!("Joining {room_id} over federation."); - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; + let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; - info!("make_join finished"); + info!("make_join finished"); - let room_version_id = match make_join_response.room_version { - Some(room_version) - if services() - .globals - .supported_room_versions() - .contains(&room_version) => - { - room_version - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; + let room_version_id = match make_join_response.room_version { + Some(room_version) if services().globals.supported_room_versions().contains(&room_version) => room_version, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) + .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? 
- .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + let join_authorized_via_users_server = join_event_stub + .get("content") + .map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str()) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - // TODO: Is origin needed? - join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server: join_authorized_via_users_server.clone(), - }) - .expect("event is valid, we just created it"), - ); + // TODO: Is origin needed? 
+ join_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + join_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), + ), + ); + join_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason, + join_authorized_via_users_server: join_authorized_via_users_server.clone(), + }) + .expect("event is valid, we just created it"), + ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); + // We don't leave the event id in the pdu because that's only allowed in v1 or + // v2 rooms + join_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut join_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event 
ids"); + // Generate event id + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&join_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + // Add event_id back + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - // It has enough fields to be called a proper event now - let mut join_event = join_event_stub; + // It has enough fields to be called a proper event now + let mut join_event = join_event_stub; - info!("Asking {remote_server} for send_join in room {room_id}"); - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; + info!("Asking {remote_server} for send_join in room {room_id}"); + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; - info!("send_join finished"); + info!("send_join finished"); - if join_authorized_via_users_server.is_some() { - match &room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 => { - warn!("Found `join_authorised_via_users_server` but room {} is version {}. 
Ignoring.", room_id, &room_version_id); - } - // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to validate and send signatures - RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { - if let Some(signed_raw) = &send_join_response.room_state.event { - info!("There is a signed event. This room is probably using restricted joins. Adding signature to our event"); - let (signed_event_id, signed_value) = - match gen_event_id_canonical_json(signed_raw, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; + if join_authorized_via_users_server.is_some() { + match &room_version_id { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 => { + warn!( + "Found `join_authorised_via_users_server` but room {} is version {}. Ignoring.", + room_id, &room_version_id + ); + }, + // only room versions 8 and above using `join_authorized_via_users_server` (restricted joins) need to + // validate and send signatures + RoomVersionId::V8 | RoomVersionId::V9 | RoomVersionId::V10 | RoomVersionId::V11 => { + if let Some(signed_raw) = &send_join_response.room_state.event { + info!( + "There is a signed event. This room is probably using restricted joins. 
Adding signature \ + to our event" + ); + let (signed_event_id, signed_value) = + match gen_event_id_canonical_json(signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }, + }; - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } - match signed_value["signatures"] - .as_object() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent invalid signatures type", - )) - .and_then(|e| { - e.get(remote_server.as_str()).ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Server did not send its signature", - )) - }) { - Ok(signature) => { - join_event - .get_mut("signatures") - .expect("we created a valid pdu") - .as_object_mut() - .expect("we created a valid pdu") - .insert(remote_server.to_string(), signature.clone()); - } - Err(e) => { - warn!( - "Server {remote_server} sent invalid signature in sendjoin signatures for event {signed_value:?}: {e:?}", - ); - } - } - } - } - _ => { - warn!( - "Unexpected or unsupported room version {} for room {}", - &room_version_id, room_id - ); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - } - } - } + match signed_value["signatures"] + .as_object() + .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent invalid signatures type", + )) + .and_then(|e| { + e.get(remote_server.as_str()).ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Server did not send its signature", + )) + }) { + Ok(signature) => { + join_event + .get_mut("signatures") + .expect("we created a valid pdu") + .as_object_mut() + .expect("we created a valid pdu") + 
.insert(remote_server.to_string(), signature.clone()); + }, + Err(e) => { + warn!( + "Server {remote_server} sent invalid signature in sendjoin signatures for event \ + {signed_value:?}: {e:?}", + ); + }, + } + } + }, + _ => { + warn!( + "Unexpected or unsupported room version {} for room {}", + &room_version_id, room_id + ); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } + } - services().rooms.short.get_or_create_shortroomid(room_id)?; + services().rooms.short.get_or_create_shortroomid(room_id)?; - info!("Parsing join event"); - let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) - .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; + info!("Parsing join event"); + let parsed_join_pdu = PduEvent::from_id_val(event_id, join_event.clone()) + .map_err(|_| Error::BadServerResponse("Invalid join event PDU."))?; - let mut state = HashMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); + let mut state = HashMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); - info!("Fetching join signing keys"); - services() - .rooms - .event_handler - .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) - .await?; + info!("Fetching join signing keys"); + services() + .rooms + .event_handler + .fetch_join_signing_keys(&send_join_response, &room_version_id, &pub_key_map) + .await?; - info!("Going through send_join response room_state"); - for result in send_join_response - .room_state - .state - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result { - Ok(t) => t, - Err(_) => continue, - }; + info!("Going through send_join response room_state"); + for result in send_join_response + .room_state + .state + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result { + Ok(t) => t, + Err(_) => 
continue, + }; - let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { - warn!("Invalid PDU in send_join response: {} {:?}", e, value); - Error::BadServerResponse("Invalid PDU in send_join response.") - })?; + let pdu = PduEvent::from_id_val(&event_id, value.clone()).map_err(|e| { + warn!("Invalid PDU in send_join response: {} {:?}", e, value); + Error::BadServerResponse("Invalid PDU in send_join response.") + })?; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - if let Some(state_key) = &pdu.state_key { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; - state.insert(shortstatekey, pdu.event_id.clone()); - } - } + services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; + if let Some(state_key) = &pdu.state_key { + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey(&pdu.kind.to_string().into(), state_key)?; + state.insert(shortstatekey, pdu.event_id.clone()); + } + } - info!("Going through send_join response auth_chain"); - for result in send_join_response - .room_state - .auth_chain - .iter() - .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) - { - let (event_id, value) = match result { - Ok(t) => t, - Err(_) => continue, - }; + info!("Going through send_join response auth_chain"); + for result in send_join_response + .room_state + .auth_chain + .iter() + .map(|pdu| validate_and_add_event_id(pdu, &room_version_id, &pub_key_map)) + { + let (event_id, value) = match result { + Ok(t) => t, + Err(_) => continue, + }; - services() - .rooms - .outlier - .add_pdu_outlier(&event_id, &value)?; - } + services().rooms.outlier.add_pdu_outlier(&event_id, &value)?; + } - info!("Running send_join auth check"); + info!("Running send_join auth check"); - let auth_check = state_res::event_auth::auth_check( - &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), - &parsed_join_pdu, - 
None::, // TODO: third party invite - |k, s| { - services() - .rooms - .timeline - .get_pdu( - state.get( - &services() - .rooms - .short - .get_or_create_shortstatekey(&k.to_string().into(), s) - .ok()?, - )?, - ) - .ok()? - }, - ) - .map_err(|e| { - warn!("Auth check failed: {e}"); - Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") - })?; + let auth_check = state_res::event_auth::auth_check( + &state_res::RoomVersion::new(&room_version_id).expect("room version is supported"), + &parsed_join_pdu, + None::, // TODO: third party invite + |k, s| { + services() + .rooms + .timeline + .get_pdu( + state + .get(&services().rooms.short.get_or_create_shortstatekey(&k.to_string().into(), s).ok()?)?, + ) + .ok()? + }, + ) + .map_err(|e| { + warn!("Auth check failed: {e}"); + Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed") + })?; - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth check failed", - )); - } + if !auth_check { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); + } - info!("Saving state from send_join"); - let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( - room_id, - Arc::new( - state - .into_iter() - .map(|(k, id)| { - services() - .rooms - .state_compressor - .compress_state_event(k, &id) - }) - .collect::>()?, - ), - )?; + info!("Saving state from send_join"); + let (statehash_before_join, new, removed) = services().rooms.state_compressor.save_state( + room_id, + Arc::new( + state + .into_iter() + .map(|(k, id)| services().rooms.state_compressor.compress_state_event(k, &id)) + .collect::>()?, + ), + )?; - services() - .rooms - .state - .force_state(room_id, statehash_before_join, new, removed, &state_lock) - .await?; + services().rooms.state.force_state(room_id, statehash_before_join, new, removed, &state_lock).await?; - info!("Updating joined counts for new room"); - services().rooms.state_cache.update_joined_count(room_id)?; + 
info!("Updating joined counts for new room"); + services().rooms.state_cache.update_joined_count(room_id)?; - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. - let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; + // We append to state before appending the pdu, so we don't have a moment in + // time with the pdu without it's state. This is okay because append_pdu can't + // fail. + let statehash_after_join = services().rooms.state.append_to_state(&parsed_join_pdu)?; - info!("Appending new room join event"); - services() - .rooms - .timeline - .append_pdu( - &parsed_join_pdu, - join_event, - vec![(*parsed_join_pdu.event_id).to_owned()], - &state_lock, - ) - .await?; + info!("Appending new room join event"); + services() + .rooms + .timeline + .append_pdu( + &parsed_join_pdu, + join_event, + vec![(*parsed_join_pdu.event_id).to_owned()], + &state_lock, + ) + .await?; - info!("Setting final room state for new room"); - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehash_after_join, &state_lock)?; - } else { - info!("We can join locally"); + info!("Setting final room state for new room"); + // We set the room state after inserting the pdu, so that we never have a moment + // in time where events in the current room state do not exist + services().rooms.state.set_room_state(room_id, statehash_after_join, &state_lock)?; + } else { + info!("We can join locally"); - let join_rules_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomJoinRules, - "", - )?; - let power_levels_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomPowerLevels, - "", - )?; + let join_rules_event = + 
services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; + let power_levels_event = + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; - let power_levels_event_content: Option = power_levels_event - .as_ref() - .map(|power_levels_event| { - serde_json::from_str(power_levels_event.content.get()).map_err(|e| { - warn!("Invalid power levels event: {}", e); - Error::bad_database("Invalid power levels event in db.") - }) - }) - .transpose()?; + let join_rules_event_content: Option = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; + let power_levels_event_content: Option = power_levels_event + .as_ref() + .map(|power_levels_event| { + serde_json::from_str(power_levels_event.content.get()).map_err(|e| { + warn!("Invalid power levels event: {}", e); + Error::bad_database("Invalid power levels event in db.") + }) + }) + .transpose()?; - let restriction_rooms = match join_rules_event_content { - Some(RoomJoinRulesEventContent { - join_rule: JoinRule::Restricted(restricted), - }) - | Some(RoomJoinRulesEventContent { - join_rule: JoinRule::KnockRestricted(restricted), - }) => restricted - .allow - .into_iter() - .filter_map(|a| match a { - AllowRule::RoomMembership(r) => Some(r.room_id), - _ => None, - }) - .collect(), - _ => Vec::new(), - }; + let restriction_rooms = match join_rules_event_content { + Some(RoomJoinRulesEventContent { + join_rule: JoinRule::Restricted(restricted), + }) + | 
Some(RoomJoinRulesEventContent { + join_rule: JoinRule::KnockRestricted(restricted), + }) => restricted + .allow + .into_iter() + .filter_map(|a| match a { + AllowRule::RoomMembership(r) => Some(r.room_id), + _ => None, + }) + .collect(), + _ => Vec::new(), + }; - let authorized_user = restriction_rooms - .iter() - .find_map(|restriction_room_id| { - if !services() - .rooms - .state_cache - .is_joined(sender_user, restriction_room_id) - .ok()? - { - return None; - } - let authorized_user = power_levels_event_content - .as_ref() - .and_then(|c| { - c.users - .iter() - .filter(|(uid, i)| { - uid.server_name() == services().globals.server_name() - && **i > ruma::int!(0) - && services() - .rooms - .state_cache - .is_joined(uid, restriction_room_id) - .unwrap_or(false) - }) - .max_by_key(|(_, i)| *i) - .map(|(u, _)| u.to_owned()) - }) - .or_else(|| { - // TODO: Check here if user is actually allowed to invite. Currently the auth - // check will just fail in this case. - services() - .rooms - .state_cache - .room_members(restriction_room_id) - .filter_map(std::result::Result::ok) - .find(|uid| uid.server_name() == services().globals.server_name()) - }); - Some(authorized_user) - }) - .flatten(); + let authorized_user = restriction_rooms + .iter() + .find_map(|restriction_room_id| { + if !services().rooms.state_cache.is_joined(sender_user, restriction_room_id).ok()? { + return None; + } + let authorized_user = power_levels_event_content + .as_ref() + .and_then(|c| { + c.users + .iter() + .filter(|(uid, i)| { + uid.server_name() == services().globals.server_name() + && **i > ruma::int!(0) && services() + .rooms + .state_cache + .is_joined(uid, restriction_room_id) + .unwrap_or(false) + }) + .max_by_key(|(_, i)| *i) + .map(|(u, _)| u.to_owned()) + }) + .or_else(|| { + // TODO: Check here if user is actually allowed to invite. Currently the auth + // check will just fail in this case. 
+ services() + .rooms + .state_cache + .room_members(restriction_room_id) + .filter_map(std::result::Result::ok) + .find(|uid| uid.server_name() == services().globals.server_name()) + }); + Some(authorized_user) + }) + .flatten(); - let event = RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: reason.clone(), - join_authorized_via_users_server: authorized_user, - }; + let event = RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: reason.clone(), + join_authorized_via_users_server: authorized_user, + }; - // Try normal join first - let error = match services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await - { - Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), - Err(e) => e, - }; + // Try normal join first + let error = match services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await + { + Ok(_event_id) => return Ok(join_room_by_id::v3::Response::new(room_id.to_owned())), + Err(e) => e, + }; - if !restriction_rooms.is_empty() - 
&& servers - .iter() - .filter(|s| *s != services().globals.server_name()) - .count() - > 0 - { - info!( - "We couldn't do the join locally, maybe federation can help to satisfy the restricted join requirements" - ); - let (make_join_response, remote_server) = - make_join_request(sender_user, room_id, servers).await?; + if !restriction_rooms.is_empty() + && servers.iter().filter(|s| *s != services().globals.server_name()).count() > 0 + { + info!( + "We couldn't do the join locally, maybe federation can help to satisfy the restricted join \ + requirements" + ); + let (make_join_response, remote_server) = make_join_request(sender_user, room_id, servers).await?; - let room_version_id = match make_join_response.room_version { - Some(room_version_id) - if services() - .globals - .supported_room_versions() - .contains(&room_version_id) => - { - room_version_id - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; - let mut join_event_stub: CanonicalJsonObject = - serde_json::from_str(make_join_response.event.get()).map_err(|_| { - Error::BadServerResponse("Invalid make_join event json received from server.") - })?; - let join_authorized_via_users_server = join_event_stub - .get("content") - .map(|s| { - s.as_object()? - .get("join_authorised_via_users_server")? - .as_str() - }) - .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); - // TODO: Is origin needed? 
- join_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - join_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - join_event_stub.insert( - "content".to_owned(), - to_canonical_value(RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason, - join_authorized_via_users_server, - }) - .expect("event is valid, we just created it"), - ); + let room_version_id = match make_join_response.room_version { + Some(room_version_id) if services().globals.supported_room_versions().contains(&room_version_id) => { + room_version_id + }, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; + let mut join_event_stub: CanonicalJsonObject = serde_json::from_str(make_join_response.event.get()) + .map_err(|_| Error::BadServerResponse("Invalid make_join event json received from server."))?; + let join_authorized_via_users_server = join_event_stub + .get("content") + .map(|s| s.as_object()?.get("join_authorised_via_users_server")?.as_str()) + .and_then(|s| OwnedUserId::try_from(s.unwrap_or_default()).ok()); + // TODO: Is origin needed? 
+ join_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + join_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), + ), + ); + join_event_stub.insert( + "content".to_owned(), + to_canonical_value(RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason, + join_authorized_via_users_server, + }) + .expect("event is valid, we just created it"), + ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - join_event_stub.remove("event_id"); + // We don't leave the event id in the pdu because that's only allowed in v1 or + // v2 rooms + join_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut join_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut join_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); - // Generate event id - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&join_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); + // Generate event id + let 
event_id = format!( + "${}", + ruma::signatures::reference_hash(&join_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + ); + let event_id = + <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); - // Add event_id back - join_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + // Add event_id back + join_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - // It has enough fields to be called a proper event now - let join_event = join_event_stub; + // It has enough fields to be called a proper event now + let join_event = join_event_stub; - let send_join_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_join_event::v2::Request { - room_id: room_id.to_owned(), - event_id: event_id.to_owned(), - pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), - omit_members: false, - }, - ) - .await?; + let send_join_response = services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_join_event::v2::Request { + room_id: room_id.to_owned(), + event_id: event_id.to_owned(), + pdu: PduEvent::convert_to_outgoing_federation_event(join_event.clone()), + omit_members: false, + }, + ) + .await?; - if let Some(signed_raw) = send_join_response.room_state.event { - let (signed_event_id, signed_value) = - match gen_event_id_canonical_json(&signed_raw, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; + if let Some(signed_raw) = send_join_response.room_state.event { + let (signed_event_id, signed_value) = match gen_event_id_canonical_json(&signed_raw, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be 
converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }, + }; - if signed_event_id != event_id { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Server sent event with wrong event id", - )); - } + if signed_event_id != event_id { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Server sent event with wrong event id", + )); + } - drop(state_lock); - let pub_key_map = RwLock::new(BTreeMap::new()); - services() - .rooms - .event_handler - .fetch_required_signing_keys([&signed_value], &pub_key_map) - .await?; - services() - .rooms - .event_handler - .handle_incoming_pdu( - &remote_server, - &signed_event_id, - room_id, - signed_value, - true, - &pub_key_map, - ) - .await?; - } else { - return Err(error); - } - } else { - return Err(error); - } - } + drop(state_lock); + let pub_key_map = RwLock::new(BTreeMap::new()); + services().rooms.event_handler.fetch_required_signing_keys([&signed_value], &pub_key_map).await?; + services() + .rooms + .event_handler + .handle_incoming_pdu(&remote_server, &signed_event_id, room_id, signed_value, true, &pub_key_map) + .await?; + } else { + return Err(error); + } + } else { + return Err(error); + } + } - Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) + Ok(join_room_by_id::v3::Response::new(room_id.to_owned())) } async fn make_join_request( - sender_user: &UserId, - room_id: &RoomId, - servers: &[OwnedServerName], -) -> Result<( - federation::membership::prepare_join_event::v1::Response, - OwnedServerName, -)> { - let mut make_join_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in joining.", - )); + sender_user: &UserId, room_id: &RoomId, servers: &[OwnedServerName], +) -> Result<(federation::membership::prepare_join_event::v1::Response, OwnedServerName)> { + let mut make_join_response_and_server = Err(Error::BadServerResponse("No server available to assist in joining.")); 
- for remote_server in servers { - if remote_server == services().globals.server_name() { - continue; - } - info!("Asking {remote_server} for make_join"); - let make_join_response = services() - .sending - .send_federation_request( - remote_server, - federation::membership::prepare_join_event::v1::Request { - room_id: room_id.to_owned(), - user_id: sender_user.to_owned(), - ver: services().globals.supported_room_versions(), - }, - ) - .await; + for remote_server in servers { + if remote_server == services().globals.server_name() { + continue; + } + info!("Asking {remote_server} for make_join"); + let make_join_response = services() + .sending + .send_federation_request( + remote_server, + federation::membership::prepare_join_event::v1::Request { + room_id: room_id.to_owned(), + user_id: sender_user.to_owned(), + ver: services().globals.supported_room_versions(), + }, + ) + .await; - make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); + make_join_response_and_server = make_join_response.map(|r| (r, remote_server.clone())); - if make_join_response_and_server.is_ok() { - break; - } - } + if make_join_response_and_server.is_ok() { + break; + } + } - make_join_response_and_server + make_join_response_and_server } fn validate_and_add_event_id( - pdu: &RawJsonValue, - room_version: &RoomVersionId, - pub_key_map: &RwLock>>, + pdu: &RawJsonValue, room_version: &RoomVersionId, pub_key_map: &RwLock>>, ) -> Result<(OwnedEventId, CanonicalJsonObject)> { - let mut value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + let mut value: CanonicalJsonObject = 
serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); - let back_off = |id| match services() - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry(id) - { - Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; + let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { + Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&event_id) { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } - if let Err(e) = ruma::signatures::verify_event( - &*pub_key_map - 
.read() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?, - &value, - room_version, - ) { - warn!("Event {} failed verification {:?} {}", event_id, pdu, e); - back_off(event_id); - return Err(Error::BadServerResponse("Event failed verification.")); - } + if let Err(e) = ruma::signatures::verify_event( + &*pub_key_map.read().map_err(|_| Error::bad_database("RwLock is poisoned."))?, + &value, + room_version, + ) { + warn!("Event {} failed verification {:?} {}", event_id, pdu, e); + back_off(event_id); + return Err(Error::BadServerResponse("Event failed verification.")); + } - value.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + value.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - Ok((event_id, value)) + Ok((event_id, value)) } pub(crate) async fn invite_helper( - sender_user: &UserId, - user_id: &UserId, - room_id: &RoomId, - reason: Option, - is_direct: bool, + sender_user: &UserId, user_id: &UserId, room_id: &RoomId, reason: Option, is_direct: bool, ) -> Result<()> { - if !services().users.is_admin(user_id)? && services().globals.block_non_admin_invites() { - info!( - "User {sender_user} is not an admin and attempted to send an invite to room {room_id}" - ); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Invites are not allowed on this server.", - )); - } + if !services().users.is_admin(user_id)? 
&& services().globals.block_non_admin_invites() { + info!("User {sender_user} is not an admin and attempted to send an invite to room {room_id}"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Invites are not allowed on this server.", + )); + } - if user_id.server_name() != services().globals.server_name() { - let (pdu, pdu_json, invite_room_state) = { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + if user_id.server_name() != services().globals.server_name() { + let (pdu, pdu_json, invite_room_state) = { + let mutex_state = Arc::clone( + services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default(), + ); + let state_lock = mutex_state.lock().await; - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: services().users.avatar_url(user_id)?, - displayname: None, - is_direct: Some(is_direct), - membership: MembershipState::Invite, - third_party_invite: None, - blurhash: None, - reason, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); + let content = to_raw_value(&RoomMemberEventContent { + avatar_url: services().users.avatar_url(user_id)?, + displayname: None, + is_direct: Some(is_direct), + membership: MembershipState::Invite, + third_party_invite: None, + blurhash: None, + reason, + join_authorized_via_users_server: None, + }) + .expect("member event is valid value"); - let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - )?; + let (pdu, pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: 
Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + )?; - let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; + let invite_room_state = services().rooms.state.calculate_invite_state(&pdu)?; - drop(state_lock); + drop(state_lock); - (pdu, pdu_json, invite_room_state) - }; + (pdu, pdu_json, invite_room_state) + }; - let room_version_id = services().rooms.state.get_room_version(room_id)?; + let room_version_id = services().rooms.state.get_room_version(room_id)?; - let response = services() - .sending - .send_federation_request( - user_id.server_name(), - create_invite::v2::Request { - room_id: room_id.to_owned(), - event_id: (*pdu.event_id).to_owned(), - room_version: room_version_id.clone(), - event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), - invite_room_state, - }, - ) - .await?; + let response = services() + .sending + .send_federation_request( + user_id.server_name(), + create_invite::v2::Request { + room_id: room_id.to_owned(), + event_id: (*pdu.event_id).to_owned(), + room_version: room_version_id.clone(), + event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), + invite_room_state, + }, + ) + .await?; - let pub_key_map = RwLock::new(BTreeMap::new()); + let pub_key_map = RwLock::new(BTreeMap::new()); - // We do not add the event_id field to the pdu here because of signature and hashes checks - let (event_id, value) = match gen_event_id_canonical_json(&response.event, &room_version_id) - { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; + // We do not add the event_id field to the pdu here because of signature and + // hashes checks + let (event_id, value) = match gen_event_id_canonical_json(&response.event, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical 
json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }, + }; - if *pdu.event_id != *event_id { - warn!("Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", user_id.server_name(), pdu_json, value); - } + if *pdu.event_id != *event_id { + warn!( + "Server {} changed invite event, that's not allowed in the spec: ours: {:?}, theirs: {:?}", + user_id.server_name(), + pdu_json, + value + ); + } - let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event needs an origin field.", - ))?) - .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value( + value + .get("origin") + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?, + ) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services() - .rooms - .event_handler - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; + services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; - let pdu_id: Vec = services() - .rooms - .event_handler - .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) - .await? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + let pdu_id: Vec = services() + .rooms + .event_handler + .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .await? 
+ .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; - // Bind to variable because of lifetimes - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(std::result::Result::ok) - .filter(|server| &**server != services().globals.server_name()); + // Bind to variable because of lifetimes + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(std::result::Result::ok) + .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id)?; - return Ok(()); - } + return Ok(()); + } - if !services() - .rooms - .state_cache - .is_joined(sender_user, room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } + if !services().rooms.state_cache.is_joined(sender_user, room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + let state_lock = mutex_state.lock().await; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: services().users.displayname(user_id)?, - avatar_url: services().users.avatar_url(user_id)?, - is_direct: Some(is_direct), - third_party_invite: None, - blurhash: services().users.blurhash(user_id)?, - reason, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - 
unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: services().users.displayname(user_id)?, + avatar_url: services().users.avatar_url(user_id)?, + is_direct: Some(is_direct), + third_party_invite: None, + blurhash: services().users.blurhash(user_id)?, + reason, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); - Ok(()) + Ok(()) } // Make a user leave all their joined rooms pub async fn leave_all_rooms(user_id: &UserId) -> Result<()> { - let all_rooms = services() - .rooms - .state_cache - .rooms_joined(user_id) - .chain( - services() - .rooms - .state_cache - .rooms_invited(user_id) - .map(|t| t.map(|(r, _)| r)), - ) - .collect::>(); + let all_rooms = services() + .rooms + .state_cache + .rooms_joined(user_id) + .chain(services().rooms.state_cache.rooms_invited(user_id).map(|t| t.map(|(r, _)| r))) + .collect::>(); - for room_id in all_rooms { - let room_id = match room_id { - Ok(room_id) => room_id, - Err(_) => continue, - }; + for room_id in all_rooms { + let room_id = match room_id { + Ok(room_id) => room_id, + Err(_) => continue, + }; - let _ = leave_room(user_id, &room_id, None).await; - } + let _ = leave_room(user_id, &room_id, None).await; + } - Ok(()) + Ok(()) } pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option) -> Result<()> { - // Ask a remote server if we don't have this room - if !services().rooms.metadata.exists(room_id)? 
- && room_id.server_name() != Some(services().globals.server_name()) - { - if let Err(e) = remote_leave_room(user_id, room_id).await { - warn!("Failed to leave room {} remotely: {}", user_id, e); - // Don't tell the client about this error - } + // Ask a remote server if we don't have this room + if !services().rooms.metadata.exists(room_id)? && room_id.server_name() != Some(services().globals.server_name()) { + if let Err(e) = remote_leave_room(user_id, room_id).await { + warn!("Failed to leave room {} remotely: {}", user_id, e); + // Don't tell the client about this error + } - let last_state = services() - .rooms - .state_cache - .invite_state(user_id, room_id)? - .map_or_else( - || services().rooms.state_cache.left_state(user_id, room_id), - |s| Ok(Some(s)), - )?; + let last_state = services() + .rooms + .state_cache + .invite_state(user_id, room_id)? + .map_or_else(|| services().rooms.state_cache.left_state(user_id, room_id), |s| Ok(Some(s)))?; - // We always drop the invite, we can't rely on other servers - services() - .rooms - .state_cache - .update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - last_state, - true, - ) - .await?; - } else { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + // We always drop the invite, we can't rely on other servers + services() + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + last_state, + true, + ) + .await?; + } else { + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + let state_lock = mutex_state.lock().await; - let member_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - user_id.as_str(), - )?; + let 
member_event = + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?; - // Fix for broken rooms - let member_event = match member_event { - None => { - error!("Trying to leave a room you are not a member of."); + // Fix for broken rooms + let member_event = match member_event { + None => { + error!("Trying to leave a room you are not a member of."); - services() - .rooms - .state_cache - .update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - None, - true, - ) - .await?; - return Ok(()); - } - Some(e) => e, - }; + services() + .rooms + .state_cache + .update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + None, + true, + ) + .await?; + return Ok(()); + }, + Some(e) => e, + }; - let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()) - .map_err(|e| { - error!("Invalid room member event in database: {}", e); - Error::bad_database("Invalid member event in database.") - })?; + let mut event: RoomMemberEventContent = serde_json::from_str(member_event.content.get()).map_err(|e| { + error!("Invalid room member event in database: {}", e); + Error::bad_database("Invalid member event in database.") + })?; - event.membership = MembershipState::Leave; - event.reason = reason; + event.membership = MembershipState::Leave; + event.reason = reason; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&event).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, - room_id, - &state_lock, - ) - .await?; - } + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&event).expect("event is valid, we just created it"), + unsigned: None, + 
state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + room_id, + &state_lock, + ) + .await?; + } - Ok(()) + Ok(()) } async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut make_leave_response_and_server = Err(Error::BadServerResponse( - "No server available to assist in leaving.", - )); + let mut make_leave_response_and_server = Err(Error::BadServerResponse("No server available to assist in leaving.")); - let invite_state = services() - .rooms - .state_cache - .invite_state(user_id, room_id)? - .ok_or(Error::BadRequest( - ErrorKind::BadState, - "User is not invited.", - ))?; + let invite_state = services() + .rooms + .state_cache + .invite_state(user_id, room_id)? + .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let servers: HashSet<_> = invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(std::borrow::ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(); - for remote_server in servers { - let make_leave_response = services() - .sending - .send_federation_request( - &remote_server, - federation::membership::prepare_leave_event::v1::Request { - room_id: room_id.to_owned(), - user_id: user_id.to_owned(), - }, - ) - .await; + for remote_server in servers { + let make_leave_response = services() + .sending + .send_federation_request( + &remote_server, + 
federation::membership::prepare_leave_event::v1::Request { + room_id: room_id.to_owned(), + user_id: user_id.to_owned(), + }, + ) + .await; - make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); + make_leave_response_and_server = make_leave_response.map(|r| (r, remote_server)); - if make_leave_response_and_server.is_ok() { - break; - } - } + if make_leave_response_and_server.is_ok() { + break; + } + } - let (make_leave_response, remote_server) = make_leave_response_and_server?; + let (make_leave_response, remote_server) = make_leave_response_and_server?; - let room_version_id = match make_leave_response.room_version { - Some(version) - if services() - .globals - .supported_room_versions() - .contains(&version) => - { - version - } - _ => return Err(Error::BadServerResponse("Room version is not supported")), - }; + let room_version_id = match make_leave_response.room_version { + Some(version) if services().globals.supported_room_versions().contains(&version) => version, + _ => return Err(Error::BadServerResponse("Room version is not supported")), + }; - let mut leave_event_stub = serde_json::from_str::( - make_leave_response.event.get(), - ) - .map_err(|_| Error::BadServerResponse("Invalid make_leave event json received from server."))?; + let mut leave_event_stub = serde_json::from_str::(make_leave_response.event.get()) + .map_err(|_| Error::BadServerResponse("Invalid make_leave event json received from server."))?; - // TODO: Is origin needed? - leave_event_stub.insert( - "origin".to_owned(), - CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), - ); - leave_event_stub.insert( - "origin_server_ts".to_owned(), - CanonicalJsonValue::Integer( - utils::millis_since_unix_epoch() - .try_into() - .expect("Timestamp is valid js_int value"), - ), - ); - // We don't leave the event id in the pdu because that's only allowed in v1 or v2 rooms - leave_event_stub.remove("event_id"); + // TODO: Is origin needed? 
+ leave_event_stub.insert( + "origin".to_owned(), + CanonicalJsonValue::String(services().globals.server_name().as_str().to_owned()), + ); + leave_event_stub.insert( + "origin_server_ts".to_owned(), + CanonicalJsonValue::Integer( + utils::millis_since_unix_epoch().try_into().expect("Timestamp is valid js_int value"), + ), + ); + // We don't leave the event id in the pdu because that's only allowed in v1 or + // v2 rooms + leave_event_stub.remove("event_id"); - // In order to create a compatible ref hash (EventID) the `hashes` field needs to be present - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut leave_event_stub, - &room_version_id, - ) - .expect("event is valid, we just created it"); + // In order to create a compatible ref hash (EventID) the `hashes` field needs + // to be present + ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut leave_event_stub, + &room_version_id, + ) + .expect("event is valid, we just created it"); - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&leave_event_stub, &room_version_id) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); - // Add event_id back - leave_event_stub.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + // Add event_id back + leave_event_stub.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - // It has enough fields to be called a proper event now - let leave_event = leave_event_stub; + // It has enough fields to 
be called a proper event now + let leave_event = leave_event_stub; - services() - .sending - .send_federation_request( - &remote_server, - federation::membership::create_leave_event::v2::Request { - room_id: room_id.to_owned(), - event_id, - pdu: PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), - }, - ) - .await?; + services() + .sending + .send_federation_request( + &remote_server, + federation::membership::create_leave_event::v2::Request { + room_id: room_id.to_owned(), + event_id, + pdu: PduEvent::convert_to_outgoing_federation_event(leave_event.clone()), + }, + ) + .await?; - Ok(()) + Ok(()) } diff --git a/src/api/client_server/message.rs b/src/api/client_server/message.rs index 156e1f5c..38e9b65f 100644 --- a/src/api/client_server/message.rs +++ b/src/api/client_server/message.rs @@ -1,316 +1,284 @@ -use crate::{ - service::{pdu::PduBuilder, rooms::timeline::PduCount}, - services, utils, Error, Result, Ruma, +use std::{ + collections::{BTreeMap, HashSet}, + sync::Arc, }; + use ruma::{ - api::client::{ - error::ErrorKind, - message::{get_message_events, send_message_event}, - }, - events::{StateEventType, TimelineEventType}, + api::client::{ + error::ErrorKind, + message::{get_message_events, send_message_event}, + }, + events::{StateEventType, TimelineEventType}, }; use serde_json::from_str; -use std::{ - collections::{BTreeMap, HashSet}, - sync::Arc, + +use crate::{ + service::{pdu::PduBuilder, rooms::timeline::PduCount}, + services, utils, Error, Result, Ruma, }; /// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}` /// /// Send a message event into the room. 
/// -/// - Is a NOOP if the txn id was already used before and returns the same event id again +/// - Is a NOOP if the txn id was already used before and returns the same event +/// id again /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed pub async fn send_message_event_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_deref(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_deref(); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - // Forbid m.room.encrypted if encryption is disabled - if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() - && !services().globals.allow_encryption() - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Encryption has been disabled", - )); - } + // Forbid m.room.encrypted if encryption is disabled + if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() && !services().globals.allow_encryption() + { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled")); + } - // certain event types require certain fields to be valid in request bodies. - // this helps prevent attempting to handle events that we can't deserialise later so don't waste resources on it. 
- // - // see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type. - match body.event_type.to_string().into() { - TimelineEventType::RoomMessage => { - let body_field = body.body.body.get_field::("body"); - let msgtype_field = body.body.body.get_field::("msgtype"); + // certain event types require certain fields to be valid in request bodies. + // this helps prevent attempting to handle events that we can't deserialise + // later so don't waste resources on it. + // + // see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type. + match body.event_type.to_string().into() { + TimelineEventType::RoomMessage => { + let body_field = body.body.body.get_field::("body"); + let msgtype_field = body.body.body.get_field::("msgtype"); - if body_field.is_err() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "'body' field in JSON request is invalid", - )); - } + if body_field.is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "'body' field in JSON request is invalid", + )); + } - if msgtype_field.is_err() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "'msgtype' field in JSON request is invalid", - )); - } - } - TimelineEventType::RoomName => { - let name_field = body.body.body.get_field::("name"); + if msgtype_field.is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "'msgtype' field in JSON request is invalid", + )); + } + }, + TimelineEventType::RoomName => { + let name_field = body.body.body.get_field::("name"); - if name_field.is_err() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "'name' field in JSON request is invalid", - )); - } - } - TimelineEventType::RoomTopic => { - let topic_field = body.body.body.get_field::("topic"); + if name_field.is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "'name' field in JSON request is invalid", + )); + } + }, + TimelineEventType::RoomTopic => { + 
let topic_field = body.body.body.get_field::("topic"); - if topic_field.is_err() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "'topic' field in JSON request is invalid", - )); - } - } - _ => {} // event may be custom/experimental or can be empty don't do anything with it - }; + if topic_field.is_err() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "'topic' field in JSON request is invalid", + )); + } + }, + _ => {}, // event may be custom/experimental or can be empty don't do anything with it + }; - // Check if this is a new transaction id - if let Some(response) = - services() - .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? - { - // The client might have sent a txnid of the /sendToDevice endpoint - // This txnid has no response associated with it - if response.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to use txn id already used for an incompatible endpoint.", - )); - } + // Check if this is a new transaction id + if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? { + // The client might have sent a txnid of the /sendToDevice endpoint + // This txnid has no response associated with it + if response.is_empty() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Tried to use txn id already used for an incompatible endpoint.", + )); + } - let event_id = utils::string_from_bytes(&response) - .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; - return Ok(send_message_event::v3::Response { event_id }); - } + let event_id = utils::string_from_bytes(&response) + .map_err(|_| Error::bad_database("Invalid txnid bytes in database."))? 
+ .try_into() + .map_err(|_| Error::bad_database("Invalid event id in txnid data."))?; + return Ok(send_message_event::v3::Response { + event_id, + }); + } - let mut unsigned = BTreeMap::new(); - unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); + let mut unsigned = BTreeMap::new(); + unsigned.insert("transaction_id".to_owned(), body.txn_id.to_string().into()); - let event_id = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: body.event_type.to_string().into(), - content: from_str(body.body.body.json().get()) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, - unsigned: Some(unsigned), - state_key: None, - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + let event_id = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: body.event_type.to_string().into(), + content: from_str(body.body.body.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid JSON body."))?, + unsigned: Some(unsigned), + state_key: None, + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - services().transaction_ids.add_txnid( - sender_user, - sender_device, - &body.txn_id, - event_id.as_bytes(), - )?; + services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?; - drop(state_lock); + drop(state_lock); - Ok(send_message_event::v3::Response::new( - (*event_id).to_owned(), - )) + Ok(send_message_event::v3::Response::new((*event_id).to_owned())) } /// # `GET /_matrix/client/r0/rooms/{roomId}/messages` /// /// Allows paginating through room history. 
/// -/// - Only works if the user is joined (TODO: always allow, but only show events where the user was +/// - Only works if the user is joined (TODO: always allow, but only show events +/// where the user was /// joined, depending on history_visibility) pub async fn get_message_events_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match body.dir { - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; + let from = match body.from.clone() { + Some(from) => PduCount::try_from_string(&from)?, + None => match body.dir { + ruma::api::Direction::Forward => PduCount::min(), + ruma::api::Direction::Backward => PduCount::max(), + }, + }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); - services().rooms.lazy_loading.lazy_load_confirm_delivery( - sender_user, - sender_device, - &body.room_id, - from, - )?; + services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?; - let limit = u64::from(body.limit).min(100) as usize; + let limit = u64::from(body.limit).min(100) as usize; - let next_token; + let next_token; - let mut resp = get_message_events::v3::Response::new(); + let mut resp = get_message_events::v3::Response::new(); - let mut lazy_loaded = HashSet::new(); + let mut lazy_loaded = HashSet::new(); - match body.dir { - ruma::api::Direction::Forward => { - let events_after: Vec<_> = services() - .rooms - .timeline - 
.pdus_after(sender_user, &body.room_id, from)? - .take(limit) - .filter_map(std::result::Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` - .collect(); + match body.dir { + ruma::api::Direction::Forward => { + let events_after: Vec<_> = services() + .rooms + .timeline + .pdus_after(sender_user, &body.room_id, from)? + .take(limit) + .filter_map(std::result::Result::ok) // Filter out buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) + .unwrap_or(false) + }) + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .collect(); - for (_, event) in &events_after { - /* TODO: Remove this when these are resolved: - * https://github.com/vector-im/element-android/issues/3417 - * https://github.com/vector-im/element-web/issues/21034 - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? { - lazy_loaded.insert(event.sender.clone()); - } - */ - lazy_loaded.insert(event.sender.clone()); - } + for (_, event) in &events_after { + /* TODO: Remove this when these are resolved: + * https://github.com/vector-im/element-android/issues/3417 + * https://github.com/vector-im/element-web/issues/21034 + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + &body.room_id, + &event.sender, + )? 
{ + lazy_loaded.insert(event.sender.clone()); + } + */ + lazy_loaded.insert(event.sender.clone()); + } - next_token = events_after.last().map(|(count, _)| count).copied(); + next_token = events_after.last().map(|(count, _)| count).copied(); - let events_after: Vec<_> = events_after - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); - resp.start = from.stringify(); - resp.end = next_token.map(|count| count.stringify()); - resp.chunk = events_after; - } - ruma::api::Direction::Backward => { - services() - .rooms - .timeline - .backfill_if_required(&body.room_id, from) - .await?; - let events_before: Vec<_> = services() - .rooms - .timeline - .pdus_until(sender_user, &body.room_id, from)? - .take(limit) - .filter_map(std::result::Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` - .collect(); + resp.start = from.stringify(); + resp.end = next_token.map(|count| count.stringify()); + resp.chunk = events_after; + }, + ruma::api::Direction::Backward => { + services().rooms.timeline.backfill_if_required(&body.room_id, from).await?; + let events_before: Vec<_> = services() + .rooms + .timeline + .pdus_until(sender_user, &body.room_id, from)? 
+ .take(limit) + .filter_map(std::result::Result::ok) // Filter out buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) + .unwrap_or(false) + }) + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .collect(); - for (_, event) in &events_before { - /* TODO: Remove this when these are resolved: - * https://github.com/vector-im/element-android/issues/3417 - * https://github.com/vector-im/element-web/issues/21034 - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - &body.room_id, - &event.sender, - )? { - lazy_loaded.insert(event.sender.clone()); - } - */ - lazy_loaded.insert(event.sender.clone()); - } + for (_, event) in &events_before { + /* TODO: Remove this when these are resolved: + * https://github.com/vector-im/element-android/issues/3417 + * https://github.com/vector-im/element-web/issues/21034 + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + &body.room_id, + &event.sender, + )? 
{ + lazy_loaded.insert(event.sender.clone()); + } + */ + lazy_loaded.insert(event.sender.clone()); + } - next_token = events_before.last().map(|(count, _)| count).copied(); + next_token = events_before.last().map(|(count, _)| count).copied(); - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(); + let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(); - resp.start = from.stringify(); - resp.end = next_token.map(|count| count.stringify()); - resp.chunk = events_before; - } - } + resp.start = from.stringify(); + resp.end = next_token.map(|count| count.stringify()); + resp.chunk = events_before; + }, + } - resp.state = Vec::new(); - for ll_id in &lazy_loaded { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( - &body.room_id, - &StateEventType::RoomMember, - ll_id.as_str(), - )? { - resp.state.push(member_event.to_state_event()); - } - } + resp.state = Vec::new(); + for ll_id in &lazy_loaded { + if let Some(member_event) = services().rooms.state_accessor.room_state_get( + &body.room_id, + &StateEventType::RoomMember, + ll_id.as_str(), + )? 
{ + resp.state.push(member_event.to_state_event()); + } + } - // TODO: enable again when we are sure clients can handle it - /* - if let Some(next_token) = next_token { - services().rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - &body.room_id, - lazy_loaded, - next_token, - ); - } - */ + // TODO: enable again when we are sure clients can handle it + /* + if let Some(next_token) = next_token { + services().rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + &body.room_id, + lazy_loaded, + next_token, + ); + } + */ - Ok(resp) + Ok(resp) } diff --git a/src/api/client_server/presence.rs b/src/api/client_server/presence.rs index 3d56e6f3..201b8ad8 100644 --- a/src/api/client_server/presence.rs +++ b/src/api/client_server/presence.rs @@ -1,38 +1,35 @@ -use crate::{services, Error, Result, Ruma}; -use ruma::api::client::{ - error::ErrorKind, - presence::{get_presence, set_presence}, -}; use std::time::Duration; +use ruma::api::client::{ + error::ErrorKind, + presence::{get_presence, set_presence}, +}; + +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/presence/{userId}/status` /// /// Sets the presence state of the sender user. 
-pub async fn set_presence_route( - body: Ruma, -) -> Result { - if !services().globals.allow_local_presence() { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Presence is disabled on this server", - )); - } +pub async fn set_presence_route(body: Ruma) -> Result { + if !services().globals.allow_local_presence() { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server")); + } - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for room_id in services().rooms.state_cache.rooms_joined(sender_user) { - let room_id = room_id?; + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + for room_id in services().rooms.state_cache.rooms_joined(sender_user) { + let room_id = room_id?; - services().rooms.edus.presence.set_presence( - &room_id, - sender_user, - body.presence.clone(), - None, - None, - body.status_msg.clone(), - )?; - } + services().rooms.edus.presence.set_presence( + &room_id, + sender_user, + body.presence.clone(), + None, + None, + body.status_msg.clone(), + )?; + } - Ok(set_presence::v3::Response {}) + Ok(set_presence::v3::Response {}) } /// # `GET /_matrix/client/r0/presence/{userId}/status` @@ -40,53 +37,36 @@ pub async fn set_presence_route( /// Gets the presence state of the given user. 
/// /// - Only works if you share a room with the user -pub async fn get_presence_route( - body: Ruma, -) -> Result { - if !services().globals.allow_local_presence() { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Presence is disabled on this server", - )); - } +pub async fn get_presence_route(body: Ruma) -> Result { + if !services().globals.allow_local_presence() { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server")); + } - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let mut presence_event = None; + let mut presence_event = None; - for room_id in services() - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? - { - let room_id = room_id?; + for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? { + let room_id = room_id?; - if let Some(presence) = services() - .rooms - .edus - .presence - .get_presence(&room_id, sender_user)? - { - presence_event = Some(presence); - break; - } - } + if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? { + presence_event = Some(presence); + break; + } + } - if let Some(presence) = presence_event { - Ok(get_presence::v3::Response { - // TODO: Should ruma just use the presenceeventcontent type here? - status_msg: presence.content.status_msg, - currently_active: presence.content.currently_active, - last_active_ago: presence - .content - .last_active_ago - .map(|millis| Duration::from_millis(millis.into())), - presence: presence.content.presence, - }) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "Presence state for this user was not found", - )) - } + if let Some(presence) = presence_event { + Ok(get_presence::v3::Response { + // TODO: Should ruma just use the presenceeventcontent type here? 
+ status_msg: presence.content.status_msg, + currently_active: presence.content.currently_active, + last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())), + presence: presence.content.presence, + }) + } else { + Err(Error::BadRequest( + ErrorKind::NotFound, + "Presence state for this user was not found", + )) + } } diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index 74e2a4d4..5aa2be73 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -1,17 +1,15 @@ use std::sync::Arc; use ruma::{ - api::{ - client::{ - error::ErrorKind, - profile::{ - get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name, - }, - }, - federation, - }, - events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType}, - presence::PresenceState, + api::{ + client::{ + error::ErrorKind, + profile::{get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name}, + }, + federation, + }, + events::{room::member::RoomMemberEventContent, StateEventType, TimelineEventType}, + presence::PresenceState, }; use serde_json::value::to_raw_value; @@ -23,87 +21,62 @@ use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; /// /// - Also makes sure other users receive the update using presence EDUs pub async fn set_displayname_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .users - .set_displayname(sender_user, body.displayname.clone()) - .await?; + services().users.set_displayname(sender_user, body.displayname.clone()).await?; - // Send a new membership event and presence update into all joined rooms - let all_rooms_joined: Vec<_> = services() - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(std::result::Result::ok) - .map(|room_id| { 
- Ok::<_, Error>(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - displayname: body.displayname.clone(), - ..serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get( - &room_id, - &StateEventType::RoomMember, - sender_user.as_str(), - )? - .ok_or_else(|| { - Error::bad_database( - "Tried to send displayname update for user not in the \ - room.", - ) - })? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - room_id, - )) - }) - .filter_map(std::result::Result::ok) - .collect(); + // Send a new membership event and presence update into all joined rooms + let all_rooms_joined: Vec<_> = services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(std::result::Result::ok) + .map(|room_id| { + Ok::<_, Error>(( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + displayname: body.displayname.clone(), + ..serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())? + .ok_or_else(|| { + Error::bad_database("Tried to send displayname update for user not in the room.") + })? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Database contains invalid PDU."))? 
+ }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + room_id, + )) + }) + .filter_map(std::result::Result::ok) + .collect(); - for (pdu_builder, room_id) in all_rooms_joined { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + for (pdu_builder, room_id) in all_rooms_joined { + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - let _ = services() - .rooms - .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) - .await; - } + let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; + } - if services().globals.allow_local_presence() { - // Presence update - services() - .rooms - .edus - .presence - .ping_presence(sender_user, PresenceState::Online)?; - } + if services().globals.allow_local_presence() { + // Presence update + services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?; + } - Ok(set_display_name::v3::Response {}) + Ok(set_display_name::v3::Response {}) } /// # `GET /_matrix/client/v3/profile/{userId}/displayname` @@ -113,55 +86,44 @@ pub async fn set_displayname_route( /// - If user is on another server and we do not have a local copy already /// fetch displayname over federation pub async fn get_displayname_route( - body: Ruma, + body: Ruma, ) -> Result { - if body.user_id.server_name() != services().globals.server_name() { - // Create and update our local copy of the user - if let Ok(response) = services() - .sending - .send_federation_request( - body.user_id.server_name(), - federation::query::get_profile_information::v1::Request { - user_id: body.user_id.clone(), - field: None, // we want the full 
user's profile to update locally too - }, - ) - .await - { - if !services().users.exists(&body.user_id)? { - services().users.create(&body.user_id, None)?; - } + if body.user_id.server_name() != services().globals.server_name() { + // Create and update our local copy of the user + if let Ok(response) = services() + .sending + .send_federation_request( + body.user_id.server_name(), + federation::query::get_profile_information::v1::Request { + user_id: body.user_id.clone(), + field: None, // we want the full user's profile to update locally too + }, + ) + .await + { + if !services().users.exists(&body.user_id)? { + services().users.create(&body.user_id, None)?; + } - services() - .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; - services() - .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; - services() - .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; + services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; + services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; - return Ok(get_display_name::v3::Response { - displayname: response.displayname, - }); - } - } + return Ok(get_display_name::v3::Response { + displayname: response.displayname, + }); + } + } - if !services().users.exists(&body.user_id)? { - // Return 404 if this user doesn't exist and we couldn't fetch it over federation - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Profile was not found.", - )); - } + if !services().users.exists(&body.user_id)? 
{ + // Return 404 if this user doesn't exist and we couldn't fetch it over + // federation + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + } - Ok(get_display_name::v3::Response { - displayname: services().users.displayname(&body.user_id)?, - }) + Ok(get_display_name::v3::Response { + displayname: services().users.displayname(&body.user_id)?, + }) } /// # `PUT /_matrix/client/r0/profile/{userId}/avatar_url` @@ -169,93 +131,63 @@ pub async fn get_displayname_route( /// Updates the avatar_url and blurhash. /// /// - Also makes sure other users receive the update using presence EDUs -pub async fn set_avatar_url_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn set_avatar_url_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .users - .set_avatar_url(sender_user, body.avatar_url.clone()) - .await?; + services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?; - services() - .users - .set_blurhash(sender_user, body.blurhash.clone()) - .await?; + services().users.set_blurhash(sender_user, body.blurhash.clone()).await?; - // Send a new membership event and presence update into all joined rooms - let all_joined_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(std::result::Result::ok) - .map(|room_id| { - Ok::<_, Error>(( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - avatar_url: body.avatar_url.clone(), - ..serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get( - &room_id, - &StateEventType::RoomMember, - sender_user.as_str(), - )? - .ok_or_else(|| { - Error::bad_database( - "Tried to send displayname update for user not in the \ - room.", - ) - })? 
- .content - .get(), - ) - .map_err(|_| Error::bad_database("Database contains invalid PDU."))? - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - room_id, - )) - }) - .filter_map(std::result::Result::ok) - .collect(); + // Send a new membership event and presence update into all joined rooms + let all_joined_rooms: Vec<_> = services() + .rooms + .state_cache + .rooms_joined(sender_user) + .filter_map(std::result::Result::ok) + .map(|room_id| { + Ok::<_, Error>(( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + avatar_url: body.avatar_url.clone(), + ..serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())? + .ok_or_else(|| { + Error::bad_database("Tried to send displayname update for user not in the room.") + })? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Database contains invalid PDU."))? 
+ }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + room_id, + )) + }) + .filter_map(std::result::Result::ok) + .collect(); - for (pdu_builder, room_id) in all_joined_rooms { - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + for (pdu_builder, room_id) in all_joined_rooms { + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - let _ = services() - .rooms - .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) - .await; - } + let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await; + } - if services().globals.allow_local_presence() { - // Presence update - services() - .rooms - .edus - .presence - .ping_presence(sender_user, PresenceState::Online)?; - } + if services().globals.allow_local_presence() { + // Presence update + services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?; + } - Ok(set_avatar_url::v3::Response {}) + Ok(set_avatar_url::v3::Response {}) } /// # `GET /_matrix/client/v3/profile/{userId}/avatar_url` @@ -264,58 +196,45 @@ pub async fn set_avatar_url_route( /// /// - If user is on another server and we do not have a local copy already /// fetch avatar_url and blurhash over federation -pub async fn get_avatar_url_route( - body: Ruma, -) -> Result { - if body.user_id.server_name() != services().globals.server_name() { - // Create and update our local copy of the user - if let Ok(response) = services() - .sending - .send_federation_request( - body.user_id.server_name(), - federation::query::get_profile_information::v1::Request { - user_id: body.user_id.clone(), - field: None, // we want the full 
user's profile to update locally as well - }, - ) - .await - { - if !services().users.exists(&body.user_id)? { - services().users.create(&body.user_id, None)?; - } +pub async fn get_avatar_url_route(body: Ruma) -> Result { + if body.user_id.server_name() != services().globals.server_name() { + // Create and update our local copy of the user + if let Ok(response) = services() + .sending + .send_federation_request( + body.user_id.server_name(), + federation::query::get_profile_information::v1::Request { + user_id: body.user_id.clone(), + field: None, // we want the full user's profile to update locally as well + }, + ) + .await + { + if !services().users.exists(&body.user_id)? { + services().users.create(&body.user_id, None)?; + } - services() - .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; - services() - .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; - services() - .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; + services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; + services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; - return Ok(get_avatar_url::v3::Response { - avatar_url: response.avatar_url, - blurhash: response.blurhash, - }); - } - } + return Ok(get_avatar_url::v3::Response { + avatar_url: response.avatar_url, + blurhash: response.blurhash, + }); + } + } - if !services().users.exists(&body.user_id)? { - // Return 404 if this user doesn't exist and we couldn't fetch it over federation - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Profile was not found.", - )); - } + if !services().users.exists(&body.user_id)? 
{ + // Return 404 if this user doesn't exist and we couldn't fetch it over + // federation + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + } - Ok(get_avatar_url::v3::Response { - avatar_url: services().users.avatar_url(&body.user_id)?, - blurhash: services().users.blurhash(&body.user_id)?, - }) + Ok(get_avatar_url::v3::Response { + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, + }) } /// # `GET /_matrix/client/v3/profile/{userId}` @@ -324,58 +243,45 @@ pub async fn get_avatar_url_route( /// /// - If user is on another server and we do not have a local copy already, /// fetch profile over federation. -pub async fn get_profile_route( - body: Ruma, -) -> Result { - if body.user_id.server_name() != services().globals.server_name() { - // Create and update our local copy of the user - if let Ok(response) = services() - .sending - .send_federation_request( - body.user_id.server_name(), - federation::query::get_profile_information::v1::Request { - user_id: body.user_id.clone(), - field: None, - }, - ) - .await - { - if !services().users.exists(&body.user_id)? { - services().users.create(&body.user_id, None)?; - } +pub async fn get_profile_route(body: Ruma) -> Result { + if body.user_id.server_name() != services().globals.server_name() { + // Create and update our local copy of the user + if let Ok(response) = services() + .sending + .send_federation_request( + body.user_id.server_name(), + federation::query::get_profile_information::v1::Request { + user_id: body.user_id.clone(), + field: None, + }, + ) + .await + { + if !services().users.exists(&body.user_id)? 
{ + services().users.create(&body.user_id, None)?; + } - services() - .users - .set_displayname(&body.user_id, response.displayname.clone()) - .await?; - services() - .users - .set_avatar_url(&body.user_id, response.avatar_url.clone()) - .await?; - services() - .users - .set_blurhash(&body.user_id, response.blurhash.clone()) - .await?; + services().users.set_displayname(&body.user_id, response.displayname.clone()).await?; + services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?; + services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?; - return Ok(get_profile::v3::Response { - displayname: response.displayname, - avatar_url: response.avatar_url, - blurhash: response.blurhash, - }); - } - } + return Ok(get_profile::v3::Response { + displayname: response.displayname, + avatar_url: response.avatar_url, + blurhash: response.blurhash, + }); + } + } - if !services().users.exists(&body.user_id)? { - // Return 404 if this user doesn't exist and we couldn't fetch it over federation - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Profile was not found.", - )); - } + if !services().users.exists(&body.user_id)? 
{ + // Return 404 if this user doesn't exist and we couldn't fetch it over + // federation + return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found.")); + } - Ok(get_profile::v3::Response { - avatar_url: services().users.avatar_url(&body.user_id)?, - blurhash: services().users.blurhash(&body.user_id)?, - displayname: services().users.displayname(&body.user_id)?, - }) + Ok(get_profile::v3::Response { + avatar_url: services().users.avatar_url(&body.user_id)?, + blurhash: services().users.blurhash(&body.user_id)?, + displayname: services().users.displayname(&body.user_id)?, + }) } diff --git a/src/api/client_server/push.rs b/src/api/client_server/push.rs index d2c35668..a03c7db8 100644 --- a/src/api/client_server/push.rs +++ b/src/api/client_server/push.rs @@ -1,417 +1,320 @@ -use crate::{services, Error, Result, Ruma}; use ruma::{ - api::client::{ - error::ErrorKind, - push::{ - delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, - get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions, - set_pushrule_enabled, RuleScope, - }, - }, - events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, - push::{InsertPushRuleError, RemovePushRuleError}, + api::client::{ + error::ErrorKind, + push::{ + delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all, + set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope, + }, + }, + events::{push_rules::PushRulesEvent, GlobalAccountDataEventType}, + push::{InsertPushRuleError, RemovePushRuleError}, }; +use crate::{services, Error, Result, Ruma}; + /// # `GET /_matrix/client/r0/pushrules` /// /// Retrieves the push rules event for this user. 
pub async fn get_pushrules_all_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; - Ok(get_pushrules_all::v3::Response { - global: account_data.global, - }) + Ok(get_pushrules_all::v3::Response { + global: account_data.global, + }) } /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Retrieves a single specified push rule for this user. -pub async fn get_pushrule_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_pushrule_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
+ .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; - let rule = account_data - .global - .get(body.kind.clone(), &body.rule_id) - .map(Into::into); + let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into); - if let Some(rule) = rule { - Ok(get_pushrule::v3::Response { rule }) - } else { - Err(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - )) - } + if let Some(rule) = rule { + Ok(get_pushrule::v3::Response { + rule, + }) + } else { + Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")) + } } /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Creates a single specified push rule for this user. -pub async fn set_pushrule_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let body = body.body; +pub async fn set_pushrule_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? 
- .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if let Err(error) = account_data.content.global.insert( - body.rule.clone(), - body.after.as_deref(), - body.before.as_deref(), - ) { - let err = match error { - InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( - ErrorKind::InvalidParam, - "Rule IDs starting with a dot are reserved for server-default rules.", - ), - InsertPushRuleError::InvalidRuleId => Error::BadRequest( - ErrorKind::InvalidParam, - "Rule ID containing invalid characters.", - ), - InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest( - ErrorKind::InvalidParam, - "Can't place a push rule relatively to a server-default rule.", - ), - InsertPushRuleError::UnknownRuleId => Error::BadRequest( - ErrorKind::NotFound, - "The before or after rule could not be found.", - ), - InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest( - ErrorKind::InvalidParam, - "The before rule has a higher priority than the after rule.", - ), - _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."), - }; + if let Err(error) = + account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref()) + { + let err = match error { + InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest( + ErrorKind::InvalidParam, + "Rule IDs starting with a dot are reserved for server-default rules.", + ), + InsertPushRuleError::InvalidRuleId => { + Error::BadRequest(ErrorKind::InvalidParam, "Rule 
ID containing invalid characters.") + }, + InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest( + ErrorKind::InvalidParam, + "Can't place a push rule relatively to a server-default rule.", + ), + InsertPushRuleError::UnknownRuleId => { + Error::BadRequest(ErrorKind::NotFound, "The before or after rule could not be found.") + }, + InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest( + ErrorKind::InvalidParam, + "The before rule has a higher priority than the after rule.", + ), + _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."), + }; - return Err(err); - } + return Err(err); + } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services().account_data.update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; - Ok(set_pushrule::v3::Response {}) + Ok(set_pushrule::v3::Response {}) } /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` /// /// Gets the actions of a single specified push rule for this user. pub async fn get_pushrule_actions_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? 
- .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))? - .content; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))? + .content; - let global = account_data.global; - let actions = global - .get(body.kind.clone(), &body.rule_id) - .map(|rule| rule.actions().to_owned()) - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - ))?; + let global = account_data.global; + let actions = global + .get(body.kind.clone(), &body.rule_id) + .map(|rule| rule.actions().to_owned()) + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?; - Ok(get_pushrule_actions::v3::Response { actions }) + Ok(get_pushrule_actions::v3::Response { + actions, + }) } /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions` /// /// Sets the actions of a single specified push rule for this user. pub async fn set_pushrule_actions_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? 
- .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if account_data - .content - .global - .set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()) - .is_err() - { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - )); - } + if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); + } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services().account_data.update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; - Ok(set_pushrule_actions::v3::Response {}) + Ok(set_pushrule_actions::v3::Response {}) } /// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` /// /// Gets the enabled status of a single specified push rule for this user. 
pub async fn get_pushrule_enabled_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + let account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - let global = account_data.content.global; - let enabled = global - .get(body.kind.clone(), &body.rule_id) - .map(ruma::push::AnyPushRuleRef::enabled) - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - ))?; + let global = account_data.content.global; + let enabled = global + .get(body.kind.clone(), &body.rule_id) + .map(ruma::push::AnyPushRuleRef::enabled) + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?; - Ok(get_pushrule_enabled::v3::Response { enabled }) + Ok(get_pushrule_enabled::v3::Response { + enabled, + }) } /// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled` /// /// Sets the enabled status of a single specified push rule for this user. 
pub async fn set_pushrule_enabled_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if account_data - .content - .global - .set_enabled(body.kind.clone(), &body.rule_id, body.enabled) - .is_err() - { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Push rule not found.", - )); - } + if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")); + } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services().account_data.update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + 
&serde_json::to_value(account_data).expect("to json value always works"), + )?; - Ok(set_pushrule_enabled::v3::Response {}) + Ok(set_pushrule_enabled::v3::Response {}) } /// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}` /// /// Deletes a single specified push rule for this user. -pub async fn delete_pushrule_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn delete_pushrule_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if body.scope != RuleScope::Global { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Scopes other than 'global' are not supported.", - )); - } + if body.scope != RuleScope::Global { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Scopes other than 'global' are not supported.", + )); + } - let event = services() - .account_data - .get( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "PushRules event not found.", - ))?; + let event = services() + .account_data + .get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())? 
+ .ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?; - let mut account_data = serde_json::from_str::(event.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db."))?; + let mut account_data = serde_json::from_str::(event.get()) + .map_err(|_| Error::bad_database("Invalid account data event in db."))?; - if let Err(error) = account_data - .content - .global - .remove(body.kind.clone(), &body.rule_id) - { - let err = match error { - RemovePushRuleError::ServerDefault => Error::BadRequest( - ErrorKind::InvalidParam, - "Cannot delete a server-default pushrule.", - ), - RemovePushRuleError::NotFound => { - Error::BadRequest(ErrorKind::NotFound, "Push rule not found.") - } - _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."), - }; + if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) { + let err = match error { + RemovePushRuleError::ServerDefault => { + Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.") + }, + RemovePushRuleError::NotFound => Error::BadRequest(ErrorKind::NotFound, "Push rule not found."), + _ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."), + }; - return Err(err); - } + return Err(err); + } - services().account_data.update( - None, - sender_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; + services().account_data.update( + None, + sender_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; - Ok(delete_pushrule::v3::Response {}) + Ok(delete_pushrule::v3::Response {}) } /// # `GET /_matrix/client/r0/pushers` /// /// Gets all currently active pushers for the sender user. 
-pub async fn get_pushers_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_pushers_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - Ok(get_pushers::v3::Response { - pushers: services().pusher.get_pushers(sender_user)?, - }) + Ok(get_pushers::v3::Response { + pushers: services().pusher.get_pushers(sender_user)?, + }) } /// # `POST /_matrix/client/r0/pushers/set` @@ -419,14 +322,10 @@ pub async fn get_pushers_route( /// Adds a pusher for the sender user. /// /// - TODO: Handle `append` -pub async fn set_pushers_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn set_pushers_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - services() - .pusher - .set_pusher(sender_user, body.action.clone())?; + services().pusher.set_pusher(sender_user, body.action.clone())?; - Ok(set_pusher::v3::Response::default()) + Ok(set_pusher::v3::Response::default()) } diff --git a/src/api/client_server/read_marker.rs b/src/api/client_server/read_marker.rs index d01a3a7f..182748d6 100644 --- a/src/api/client_server/read_marker.rs +++ b/src/api/client_server/read_marker.rs @@ -1,182 +1,161 @@ -use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma}; -use ruma::{ - api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, - events::{ - receipt::{ReceiptThread, ReceiptType}, - RoomAccountDataEventType, - }, - MilliSecondsSinceUnixEpoch, -}; use std::collections::BTreeMap; +use ruma::{ + api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, + events::{ + receipt::{ReceiptThread, ReceiptType}, + RoomAccountDataEventType, + }, + MilliSecondsSinceUnixEpoch, +}; + +use crate::{service::rooms::timeline::PduCount, services, Error, Result, 
Ruma}; + /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// /// Sets different types of read markers. /// /// - Updates fully-read account data event to `fully_read` -/// - If `read_receipt` is set: Update private marker and public read receipt EDU -pub async fn set_read_marker_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// - If `read_receipt` is set: Update private marker and public read receipt +/// EDU +pub async fn set_read_marker_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if let Some(fully_read) = &body.fully_read { - let fully_read_event = ruma::events::fully_read::FullyReadEvent { - content: ruma::events::fully_read::FullyReadEventContent { - event_id: fully_read.clone(), - }, - }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; - } + if let Some(fully_read) = &body.fully_read { + let fully_read_event = ruma::events::fully_read::FullyReadEvent { + content: ruma::events::fully_read::FullyReadEventContent { + event_id: fully_read.clone(), + }, + }; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + )?; + } - if body.private_read_receipt.is_some() || body.read_receipt.is_some() { - services() - .rooms - .user - .reset_notification_counts(sender_user, &body.room_id)?; - } + if body.private_read_receipt.is_some() || body.read_receipt.is_some() { + services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; + } - if let Some(event) = &body.private_read_receipt { - let count = services() - .rooms - .timeline - .get_pdu_count(event)? 
- .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?; - let count = match count { - PduCount::Backfilled(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Read receipt is in backfilled timeline", - )) - } - PduCount::Normal(c) => c, - }; - services() - .rooms - .edus - .read_receipt - .private_read_set(&body.room_id, sender_user, count)?; - } + if let Some(event) = &body.private_read_receipt { + let count = services() + .rooms + .timeline + .get_pdu_count(event)? + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { + PduCount::Backfilled(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Read receipt is in backfilled timeline", + )) + }, + PduCount::Normal(c) => c, + }; + services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; + } - if let Some(event) = &body.read_receipt { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - thread: ReceiptThread::Unthreaded, - }, - ); + if let Some(event) = &body.read_receipt { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert( + sender_user.clone(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + ); - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event.to_owned(), receipts); + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(event.to_owned(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: 
ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; - } + services().rooms.edus.read_receipt.readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + )?; + } - Ok(set_read_marker::v3::Response {}) + Ok(set_read_marker::v3::Response {}) } /// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}` /// /// Sets private read marker and public read receipt EDU. -pub async fn create_receipt_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn create_receipt_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if matches!( - &body.receipt_type, - create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate - ) { - services() - .rooms - .user - .reset_notification_counts(sender_user, &body.room_id)?; - } + if matches!( + &body.receipt_type, + create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate + ) { + services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?; + } - match body.receipt_type { - create_receipt::v3::ReceiptType::FullyRead => { - let fully_read_event = ruma::events::fully_read::FullyReadEvent { - content: ruma::events::fully_read::FullyReadEventContent { - event_id: body.event_id.clone(), - }, - }; - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::FullyRead, - &serde_json::to_value(fully_read_event).expect("to json value always works"), - )?; - } - create_receipt::v3::ReceiptType::Read => { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - thread: 
ReceiptThread::Unthreaded, - }, - ); - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); + match body.receipt_type { + create_receipt::v3::ReceiptType::FullyRead => { + let fully_read_event = ruma::events::fully_read::FullyReadEvent { + content: ruma::events::fully_read::FullyReadEventContent { + event_id: body.event_id.clone(), + }, + }; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::FullyRead, + &serde_json::to_value(fully_read_event).expect("to json value always works"), + )?; + }, + create_receipt::v3::ReceiptType::Read => { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert( + sender_user.clone(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + ); + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(body.event_id.clone(), receipts); + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(body.event_id.clone(), receipts); - services().rooms.edus.read_receipt.readreceipt_update( - sender_user, - &body.room_id, - ruma::events::receipt::ReceiptEvent { - content: ruma::events::receipt::ReceiptEventContent(receipt_content), - room_id: body.room_id.clone(), - }, - )?; - } - create_receipt::v3::ReceiptType::ReadPrivate => { - let count = services() - .rooms - .timeline - .get_pdu_count(&body.event_id)? 
- .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event does not exist.", - ))?; - let count = match count { - PduCount::Backfilled(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Read receipt is in backfilled timeline", - )) - } - PduCount::Normal(c) => c, - }; - services().rooms.edus.read_receipt.private_read_set( - &body.room_id, - sender_user, - count, - )?; - } - _ => return Err(Error::bad_database("Unsupported receipt type")), - } + services().rooms.edus.read_receipt.readreceipt_update( + sender_user, + &body.room_id, + ruma::events::receipt::ReceiptEvent { + content: ruma::events::receipt::ReceiptEventContent(receipt_content), + room_id: body.room_id.clone(), + }, + )?; + }, + create_receipt::v3::ReceiptType::ReadPrivate => { + let count = services() + .rooms + .timeline + .get_pdu_count(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + let count = match count { + PduCount::Backfilled(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Read receipt is in backfilled timeline", + )) + }, + PduCount::Normal(c) => c, + }; + services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?; + }, + _ => return Err(Error::bad_database("Unsupported receipt type")), + } - Ok(create_receipt::v3::Response {}) + Ok(create_receipt::v3::Response {}) } diff --git a/src/api/client_server/redact.rs b/src/api/client_server/redact.rs index a438d248..674a67c6 100644 --- a/src/api/client_server/redact.rs +++ b/src/api/client_server/redact.rs @@ -1,58 +1,51 @@ use std::sync::Arc; -use crate::{service::pdu::PduBuilder, services, Result, Ruma}; use ruma::{ - api::client::redact::redact_event, - events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, + api::client::redact::redact_event, + events::{room::redaction::RoomRedactionEventContent, TimelineEventType}, }; - use serde_json::value::to_raw_value; +use crate::{service::pdu::PduBuilder, services, 
Result, Ruma}; + /// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}` /// /// Tries to send a redaction event into the room. /// /// - TODO: Handle txn id -pub async fn redact_event_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let body = body.body; +pub async fn redact_event_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let body = body.body; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - let event_id = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomRedaction, - content: to_raw_value(&RoomRedactionEventContent { - redacts: Some(body.event_id.clone()), - reason: body.reason.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: Some(body.event_id.into()), - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + let event_id = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomRedaction, + content: to_raw_value(&RoomRedactionEventContent { + redacts: Some(body.event_id.clone()), + reason: body.reason.clone(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: Some(body.event_id.into()), + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); - let event_id = (*event_id).to_owned(); - Ok(redact_event::v3::Response { event_id }) + let event_id = (*event_id).to_owned(); + Ok(redact_event::v3::Response { + 
event_id, + }) } diff --git a/src/api/client_server/relations.rs b/src/api/client_server/relations.rs index 124f1310..853d6011 100644 --- a/src/api/client_server/relations.rs +++ b/src/api/client_server/relations.rs @@ -1,146 +1,113 @@ use ruma::api::client::relations::{ - get_relating_events, get_relating_events_with_rel_type, - get_relating_events_with_rel_type_and_event_type, + get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, }; use crate::{service::rooms::timeline::PduCount, services, Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub async fn get_relating_events_with_rel_type_and_event_type_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match ruma::api::Direction::Backward { - // TODO: fix ruma so `body.dir` exists - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; + let from = match body.from.clone() { + Some(from) => PduCount::try_from_string(&from)?, + None => match ruma::api::Direction::Backward { + // TODO: fix ruma so `body.dir` exists + ruma::api::Direction::Forward => PduCount::min(), + ruma::api::Direction::Backward => PduCount::max(), + }, + }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - .map_or(10_usize, |u| u as usize) - .min(100); + // Use limit or else 10, with maximum 100 + let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); - let res = 
services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - Some(body.event_type.clone()), - Some(body.rel_type.clone()), - from, - to, - limit, - )?; + let res = services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + Some(body.event_type.clone()), + Some(body.rel_type.clone()), + from, + to, + limit, + )?; - Ok( - get_relating_events_with_rel_type_and_event_type::v1::Response { - chunk: res.chunk, - next_batch: res.next_batch, - prev_batch: res.prev_batch, - }, - ) + Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { + chunk: res.chunk, + next_batch: res.next_batch, + prev_batch: res.prev_batch, + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}` pub async fn get_relating_events_with_rel_type_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match ruma::api::Direction::Backward { - // TODO: fix ruma so `body.dir` exists - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; + let from = match body.from.clone() { + Some(from) => PduCount::try_from_string(&from)?, + None => match ruma::api::Direction::Backward { + // TODO: fix ruma so `body.dir` exists + ruma::api::Direction::Forward => PduCount::min(), + ruma::api::Direction::Backward => PduCount::max(), + }, + }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - .map_or(10_usize, |u| u as usize) - .min(100); + // 
Use limit or else 10, with maximum 100 + let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); - let res = services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - Some(body.rel_type.clone()), - from, - to, - limit, - )?; + let res = services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + Some(body.rel_type.clone()), + from, + to, + limit, + )?; - Ok(get_relating_events_with_rel_type::v1::Response { - chunk: res.chunk, - next_batch: res.next_batch, - prev_batch: res.prev_batch, - }) + Ok(get_relating_events_with_rel_type::v1::Response { + chunk: res.chunk, + next_batch: res.next_batch, + prev_batch: res.prev_batch, + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}` pub async fn get_relating_events_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let from = match body.from.clone() { - Some(from) => PduCount::try_from_string(&from)?, - None => match ruma::api::Direction::Backward { - // TODO: fix ruma so `body.dir` exists - ruma::api::Direction::Forward => PduCount::min(), - ruma::api::Direction::Backward => PduCount::max(), - }, - }; + let from = match body.from.clone() { + Some(from) => PduCount::try_from_string(&from)?, + None => match ruma::api::Direction::Backward { + // TODO: fix ruma so `body.dir` exists + ruma::api::Direction::Forward => PduCount::min(), + ruma::api::Direction::Backward => PduCount::max(), + }, + }; - let to = body - .to - .as_ref() - .and_then(|t| PduCount::try_from_string(t).ok()); + let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok()); - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|u| u32::try_from(u).ok()) - 
.map_or(10_usize, |u| u as usize) - .min(100); + // Use limit or else 10, with maximum 100 + let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100); - services() - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - None, - from, - to, - limit, - ) + services().rooms.pdu_metadata.paginate_relations_with_filter( + sender_user, + &body.room_id, + &body.event_id, + None, + None, + from, + to, + limit, + ) } diff --git a/src/api/client_server/report.rs b/src/api/client_server/report.rs index e424449b..255ca8e5 100644 --- a/src/api/client_server/report.rs +++ b/src/api/client_server/report.rs @@ -1,118 +1,112 @@ use std::time::Duration; -use crate::{services, utils::HtmlEscape, Error, Result, Ruma}; use rand::Rng; use ruma::{ - api::client::{error::ErrorKind, room::report_content}, - events::room::message, - int, + api::client::{error::ErrorKind, room::report_content}, + events::room::message, + int, }; use tokio::time::sleep; use tracing::{debug, info}; +use crate::{services, utils::HtmlEscape, Error, Result, Ruma}; + /// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}` /// /// Reports an inappropriate event to homeserver admins -/// -pub async fn report_event_route( - body: Ruma, -) -> Result { - // user authentication - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn report_event_route(body: Ruma) -> Result { + // user authentication + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - info!("Received /report request by user {}", sender_user); + info!("Received /report request by user {}", sender_user); - // check if we know about the reported event ID or if it's invalid - let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? 
{ - Some(pdu) => pdu, - _ => { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID is not known to us or Event ID is invalid", - )) - } - }; + // check if we know about the reported event ID or if it's invalid + let pdu = match services().rooms.timeline.get_pdu(&body.event_id)? { + Some(pdu) => pdu, + _ => { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Event ID is not known to us or Event ID is invalid", + )) + }, + }; - // check if the room ID from the URI matches the PDU's room ID - if body.room_id != pdu.room_id { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Event ID does not belong to the reported room", - )); - } + // check if the room ID from the URI matches the PDU's room ID + if body.room_id != pdu.room_id { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "Event ID does not belong to the reported room", + )); + } - // check if reporting user is in the reporting room - if !services() - .rooms - .state_cache - .room_members(&pdu.room_id) - .filter_map(std::result::Result::ok) - .any(|user_id| user_id == *sender_user) - { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "You are not in the room you are reporting.", - )); - } + // check if reporting user is in the reporting room + if !services() + .rooms + .state_cache + .room_members(&pdu.room_id) + .filter_map(std::result::Result::ok) + .any(|user_id| user_id == *sender_user) + { + return Err(Error::BadRequest( + ErrorKind::NotFound, + "You are not in the room you are reporting.", + )); + } - // check if score is in valid range - if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid score, must be within 0 to -100", - )); - }; + // check if score is in valid range + if let Some(true) = body.score.map(|s| s > int!(0) || s < int!(-100)) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Invalid score, must be within 0 to -100", + )); + }; - // check if 
report reasoning is less than or equal to 750 characters - if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Reason too long, should be 750 characters or fewer", - )); - }; + // check if report reasoning is less than or equal to 750 characters + if let Some(true) = body.reason.clone().map(|s| s.chars().count() >= 750) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Reason too long, should be 750 characters or fewer", + )); + }; - // send admin room message that we received the report with an @room ping for urgency - services() - .admin - .send_message(message::RoomMessageEventContent::text_html( - format!( - "@room Report received from: {}\n\n\ - Event ID: {}\n\ - Room ID: {}\n\ - Sent By: {}\n\n\ - Report Score: {}\n\ - Report Reason: {}", - sender_user.to_owned(), - pdu.event_id, - pdu.room_id, - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - body.reason.as_deref().unwrap_or("") - ), - format!( - "
@room Report received from: {0}\ + // send admin room message that we received the report with an @room ping for + // urgency + services().admin.send_message(message::RoomMessageEventContent::text_html( + format!( + "@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \ + Reason: {}", + sender_user.to_owned(), + pdu.event_id, + pdu.room_id, + pdu.sender.clone(), + body.score.unwrap_or_else(|| ruma::Int::from(0)), + body.reason.as_deref().unwrap_or("") + ), + format!( + "
@room Report received from: {0}\
  • Event Info
    • Event ID: {1}\ 🔗
    • Room ID: {2}\
    • Sent By: {3}
  • \ Report Info
    • Report Score: {4}
    • Report Reason: {5}
  • \
", - sender_user.to_owned(), - pdu.event_id.clone(), - pdu.room_id.clone(), - pdu.sender.clone(), - body.score.unwrap_or_else(|| ruma::Int::from(0)), - HtmlEscape(body.reason.as_deref().unwrap_or("")) - ), - )); + sender_user.to_owned(), + pdu.event_id.clone(), + pdu.room_id.clone(), + pdu.sender.clone(), + body.score.unwrap_or_else(|| ruma::Int::from(0)), + HtmlEscape(body.reason.as_deref().unwrap_or("")) + ), + )); - // even though this is kinda security by obscurity, let's still make a small random delay sending a successful response - // per spec suggestion regarding enumerating for potential events existing in our server. - let time_to_wait = rand::thread_rng().gen_range(8..21); - debug!( - "Got successful /report request, waiting {} seconds before sending successful response.", - time_to_wait - ); - sleep(Duration::from_secs(time_to_wait)).await; + // even though this is kinda security by obscurity, let's still make a small + // random delay sending a successful response per spec suggestion regarding + // enumerating for potential events existing in our server. 
+ let time_to_wait = rand::thread_rng().gen_range(8..21); + debug!( + "Got successful /report request, waiting {} seconds before sending successful response.", + time_to_wait + ); + sleep(Duration::from_secs(time_to_wait)).await; - Ok(report_content::v3::Response {}) + Ok(report_content::v3::Response {}) } diff --git a/src/api/client_server/room.rs b/src/api/client_server/room.rs index c888a90e..c0247d27 100644 --- a/src/api/client_server/room.rs +++ b/src/api/client_server/room.rs @@ -1,35 +1,34 @@ -use crate::{ - api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma, -}; +use std::{cmp::max, collections::BTreeMap, sync::Arc}; + use ruma::{ - api::client::{ - error::ErrorKind, - room::{self, aliases, create_room, get_room_event, upgrade_room}, - }, - events::{ - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - tombstone::RoomTombstoneEventContent, - topic::RoomTopicEventContent, - }, - StateEventType, TimelineEventType, - }, - int, - serde::JsonObject, - CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, - RoomVersionId, + api::client::{ + error::ErrorKind, + room::{self, aliases, create_room, get_room_event, upgrade_room}, + }, + events::{ + room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + name::RoomNameEventContent, + 
power_levels::RoomPowerLevelsEventContent, + tombstone::RoomTombstoneEventContent, + topic::RoomTopicEventContent, + }, + StateEventType, TimelineEventType, + }, + int, + serde::JsonObject, + CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId, }; use serde_json::{json, value::to_raw_value}; -use std::{cmp::max, collections::BTreeMap, sync::Arc}; use tracing::{debug, error, info, warn}; +use crate::{api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma}; + /// # `POST /_matrix/client/v3/createRoom` /// /// Creates a new room. @@ -46,634 +45,544 @@ use tracing::{debug, error, info, warn}; /// - Send events listed in initial state /// - Send events implied by `name` and `topic` /// - Send invite events -pub async fn create_room_route( - body: Ruma, -) -> Result { - use create_room::v3::RoomPreset; +pub async fn create_room_route(body: Ruma) -> Result { + use create_room::v3::RoomPreset; - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services().globals.allow_room_creation() - && !&body.from_appservice - && !services().users.is_admin(sender_user)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Room creation has been disabled.", - )); - } + if !services().globals.allow_room_creation() && !&body.from_appservice && !services().users.is_admin(sender_user)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Room creation has been disabled.")); + } - let room_id: OwnedRoomId; + let room_id: OwnedRoomId; - // checks if the user specified an explicit (custom) room_id to be created with in request body. - // falls back to normal generated room ID if not specified. 
- if let Some(CanonicalJsonValue::Object(json_body)) = &body.json_body { - match json_body.get("room_id") { - Some(custom_room_id) => { - let custom_room_id_s = custom_room_id.to_string(); + // checks if the user specified an explicit (custom) room_id to be created with + // in request body. falls back to normal generated room ID if not specified. + if let Some(CanonicalJsonValue::Object(json_body)) = &body.json_body { + match json_body.get("room_id") { + Some(custom_room_id) => { + let custom_room_id_s = custom_room_id.to_string(); - // do some checks on the custom room ID similar to room aliases - if custom_room_id_s.contains(':') { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Custom room ID contained `:` which is not allowed. Please note that this expects a localpart, not the full room ID.", - )); - } else if custom_room_id_s.contains(char::is_whitespace) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Custom room ID contained spaces which is not valid.", - )); - } else if custom_room_id_s.len() > 255 { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Custom room ID is too long.", - )); - } + // do some checks on the custom room ID similar to room aliases + if custom_room_id_s.contains(':') { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Custom room ID contained `:` which is not allowed. 
Please note that this expects a \ + localpart, not the full room ID.", + )); + } else if custom_room_id_s.contains(char::is_whitespace) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Custom room ID contained spaces which is not valid.", + )); + } else if custom_room_id_s.len() > 255 { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID is too long.")); + } - // apply forbidden room alias checks to custom room IDs too - if services() - .globals - .forbidden_room_names() - .is_match(&custom_room_id_s) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Custom room ID is forbidden.", - )); - } + // apply forbidden room alias checks to custom room IDs too + if services().globals.forbidden_room_names().is_match(&custom_room_id_s) { + return Err(Error::BadRequest(ErrorKind::Unknown, "Custom room ID is forbidden.")); + } - let full_room_id = "!".to_owned() - + &custom_room_id_s.replace('"', "") - + ":" - + services().globals.server_name().as_ref(); - debug!("Full room ID: {}", full_room_id); + let full_room_id = "!".to_owned() + + &custom_room_id_s.replace('"', "") + + ":" + services().globals.server_name().as_ref(); + debug!("Full room ID: {}", full_room_id); - room_id = RoomId::parse(full_room_id).map_err(|e| { - info!( - "User attempted to create room with custom room ID but failed parsing: {}", - e - ); - Error::BadRequest( - ErrorKind::InvalidParam, - "Custom room ID could not be parsed", - ) - })?; - } - None => room_id = RoomId::new(services().globals.server_name()), - } - } else { - room_id = RoomId::new(services().globals.server_name()); - } + room_id = RoomId::parse(full_room_id).map_err(|e| { + info!("User attempted to create room with custom room ID but failed parsing: {}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID could not be parsed") + })?; + }, + None => room_id = RoomId::new(services().globals.server_name()), + } + } else { + room_id = RoomId::new(services().globals.server_name()); + } - 
// check if room ID doesn't already exist instead of erroring on auth check - if services().rooms.short.get_shortroomid(&room_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::RoomInUse, - "Room with that custom room ID already exists", - )); - } + // check if room ID doesn't already exist instead of erroring on auth check + if services().rooms.short.get_shortroomid(&room_id)?.is_some() { + return Err(Error::BadRequest( + ErrorKind::RoomInUse, + "Room with that custom room ID already exists", + )); + } - services().rooms.short.get_or_create_shortroomid(&room_id)?; + services().rooms.short.get_or_create_shortroomid(&room_id)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - let alias: Option = - body.room_alias_name - .as_ref() - .map_or(Ok(None), |localpart| { + let alias: Option = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| { + // Basic checks on the room alias validity + if localpart.contains(':') { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the \ + full room alias.", + )); + } else if localpart.contains(char::is_whitespace) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained spaces which is not a valid room alias.", + )); + } else if localpart.len() > 255 { + // there is nothing spec-wise saying to check the limit of this, + // however absurdly long room aliases are guaranteed to be unreadable or done + // maliciously. there is no reason a room alias should even exceed 100 + // characters as is. 
generally in spec, 255 is matrix's fav number + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias is excessively long, clients may not be able to handle this. Please shorten it.", + )); + } else if localpart.contains('"') { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room alias contained `\"` which is not allowed.", + )); + } - // Basic checks on the room alias validity - if localpart.contains(':') { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the full room alias.", - )); - } else if localpart.contains(char::is_whitespace) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained spaces which is not a valid room alias.", - )); - } else if localpart.len() > 255 { - // there is nothing spec-wise saying to check the limit of this, - // however absurdly long room aliases are guaranteed to be unreadable or done maliciously. - // there is no reason a room alias should even exceed 100 characters as is. - // generally in spec, 255 is matrix's fav number - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias is excessively long, clients may not be able to handle this. 
Please shorten it.", - )); - } else if localpart.contains('"') { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room alias contained `\"` which is not allowed.", - )); - } + // check if room alias is forbidden + if services().globals.forbidden_room_names().is_match(localpart) { + return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden.")); + } - // check if room alias is forbidden - if services() - .globals - .forbidden_room_names() - .is_match(localpart) - { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Room alias name is forbidden.", - )); - } + let alias = + RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())).map_err(|e| { + warn!("Failed to parse room alias for room ID {}: {e}", room_id); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") + })?; - let alias = RoomAliasId::parse(format!( - "#{}:{}", - localpart, - services().globals.server_name() - )) - .map_err(|e| { - warn!("Failed to parse room alias for room ID {}: {e}", room_id); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.") - })?; + if services().rooms.alias.resolve_local_alias(&alias)?.is_some() { + Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists.")) + } else { + Ok(Some(alias)) + } + })?; - if services() - .rooms - .alias - .resolve_local_alias(&alias)? 
- .is_some() - { - Err(Error::BadRequest( - ErrorKind::RoomInUse, - "Room alias already exists.", - )) - } else { - Ok(Some(alias)) - } - })?; + let room_version = match body.room_version.clone() { + Some(room_version) => { + if services().globals.supported_room_versions().contains(&room_version) { + room_version + } else { + return Err(Error::BadRequest( + ErrorKind::UnsupportedRoomVersion, + "This server does not support that room version.", + )); + } + }, + None => services().globals.default_room_version(), + }; - let room_version = match body.room_version.clone() { - Some(room_version) => { - if services() - .globals - .supported_room_versions() - .contains(&room_version) - { - room_version - } else { - return Err(Error::BadRequest( - ErrorKind::UnsupportedRoomVersion, - "This server does not support that room version.", - )); - } - } - None => services().globals.default_room_version(), - }; + let content = match &body.creation_content { + Some(content) => { + let mut content = content.deserialize_as::().map_err(|e| { + error!("Failed to deserialise content as canonical JSON: {}", e); + Error::bad_database("Failed to deserialise content as canonical JSON.") + })?; + match room_version { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => { + content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Invalid creation content: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") + })?, + ); + }, + RoomVersionId::V11 => {}, // V11 removed the "creator" key + _ => { + warn!("Unexpected or unsupported room version {}", room_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } - let content = match &body.creation_content { - Some(content) => { - let mut content = content 
- .deserialize_as::() - .map_err(|e| { - error!("Failed to deserialise content as canonical JSON: {}", e); - Error::bad_database("Failed to deserialise content as canonical JSON.") - })?; - match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Invalid creation content: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") - })?, - ); - } - RoomVersionId::V11 => {} // V11 removed the "creator" key - _ => { - warn!("Unexpected or unsupported room version {}", room_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - } - } + content.insert( + "room_version".into(), + json!(room_version.as_str()) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?, + ); + content + }, + None => { + // TODO: Add correct value for v11 + let content = match room_version { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), + RoomVersionId::V11 => RoomCreateEventContent::new_v11(), + _ => { + warn!("Unexpected or unsupported room version {}", room_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + }; + let mut content = serde_json::from_str::( + to_raw_value(&content) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))? 
+ .get(), + ) + .unwrap(); + content.insert( + "room_version".into(), + json!(room_version.as_str()) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?, + ); + content + }, + }; - content.insert( - "room_version".into(), - json!(room_version.as_str()).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") - })?, - ); - content - } - None => { - // TODO: Add correct value for v11 - let content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(sender_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), - _ => { - warn!("Unexpected or unsupported room version {}", room_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - } - }; - let mut content = serde_json::from_str::( - to_raw_value(&content) - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))? - .get(), - ) - .unwrap(); - content.insert( - "room_version".into(), - json!(room_version.as_str()).try_into().map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid creation content") - })?, - ); - content - } - }; + // Validate creation content + let de_result = + serde_json::from_str::(to_raw_value(&content).expect("Invalid creation content").get()); - // Validate creation content - let de_result = serde_json::from_str::( - to_raw_value(&content) - .expect("Invalid creation content") - .get(), - ); + if de_result.is_err() { + return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")); + } - if de_result.is_err() { - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Invalid creation content", - )); - } + // 1. 
The room create event + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCreate, + content: to_raw_value(&content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - // 1. The room create event - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 2. Let the room creator join + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: services().users.avatar_url(sender_user)?, + is_direct: Some(body.is_direct), + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - // 2. 
Let the room creator join - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: Some(body.is_direct), - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 3. Power levels - // 3. Power levels + // Figure out preset. We need it for preset specific events + let preset = body.preset.clone().unwrap_or(match &body.visibility { + room::Visibility::Public => RoomPreset::PublicChat, + _ => RoomPreset::PrivateChat, // Room visibility should not be custom + }); - // Figure out preset. 
We need it for preset specific events - let preset = body.preset.clone().unwrap_or(match &body.visibility { - room::Visibility::Public => RoomPreset::PublicChat, - _ => RoomPreset::PrivateChat, // Room visibility should not be custom - }); + let mut users = BTreeMap::new(); + users.insert(sender_user.clone(), int!(100)); - let mut users = BTreeMap::new(); - users.insert(sender_user.clone(), int!(100)); + if preset == RoomPreset::TrustedPrivateChat { + for invite_ in &body.invite { + users.insert(invite_.clone(), int!(100)); + } + } - if preset == RoomPreset::TrustedPrivateChat { - for invite_ in &body.invite { - users.insert(invite_.clone(), int!(100)); - } - } + let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"); - let mut power_levels_content = serde_json::to_value(RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"); + if let Some(power_level_content_override) = &body.power_level_content_override { + let json: JsonObject = serde_json::from_str(power_level_content_override.json().get()) + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override."))?; - if let Some(power_level_content_override) = &body.power_level_content_override { - let json: JsonObject = serde_json::from_str(power_level_content_override.json().get()) - .map_err(|_| { - Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.") - })?; + for (key, value) in json { + power_levels_content[key] = value; + } + } - for (key, value) in json { - power_levels_content[key] = value; - } - } + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&power_levels_content).expect("to_raw_value always works on serde_json::Value"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: 
None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_content) - .expect("to_raw_value always works on serde_json::Value"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 4. Canonical room alias + if let Some(room_alias_id) = &alias { + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCanonicalAlias, + content: to_raw_value(&RoomCanonicalAliasEventContent { + alias: Some(room_alias_id.to_owned()), + alt_aliases: vec![], + }) + .expect("We checked that alias earlier, it must be fine"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; + } - // 4. Canonical room alias - if let Some(room_alias_id) = &alias { - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { - alias: Some(room_alias_id.to_owned()), - alt_aliases: vec![], - }) - .expect("We checked that alias earlier, it must be fine"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; - } + // 5. Events set by preset - // 5. 
Events set by preset + // 5.1 Join Rules + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomJoinRules, + content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { + RoomPreset::PublicChat => JoinRule::Public, + // according to spec "invite" is the default + _ => JoinRule::Invite, + })) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - // 5.1 Join Rules - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(match preset { - RoomPreset::PublicChat => JoinRule::Public, - // according to spec "invite" is the default - _ => JoinRule::Invite, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 5.2 History Visibility + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomHistoryVisibility, + content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - // 5.2 History Visibility - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new( - HistoryVisibility::Shared, - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 5.3 Guest Access + services() + .rooms + .timeline + 
.build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomGuestAccess, + content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { + RoomPreset::PublicChat => GuestAccess::Forbidden, + _ => GuestAccess::CanJoin, + })) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; - // 5.3 Guest Access - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new(match preset { - RoomPreset::PublicChat => GuestAccess::Forbidden, - _ => GuestAccess::CanJoin, - })) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; + // 6. Events listed in initial_state + for event in &body.initial_state { + let mut pdu_builder = event.deserialize_as::().map_err(|e| { + warn!("Invalid initial state event: {:?}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.") + })?; - // 6. 
Events listed in initial_state - for event in &body.initial_state { - let mut pdu_builder = event.deserialize_as::().map_err(|e| { - warn!("Invalid initial state event: {:?}", e); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid initial state event.") - })?; + // Implicit state key defaults to "" + pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); - // Implicit state key defaults to "" - pdu_builder.state_key.get_or_insert_with(|| "".to_owned()); + // Silently skip encryption events if they are not allowed + if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services().globals.allow_encryption() { + continue; + } - // Silently skip encryption events if they are not allowed - if pdu_builder.event_type == TimelineEventType::RoomEncryption - && !services().globals.allow_encryption() - { - continue; - } + services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await?; + } - services() - .rooms - .timeline - .build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock) - .await?; - } + // 7. Events implied by name and topic + if let Some(name) = &body.name { + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomName, + content: to_raw_value(&RoomNameEventContent::new(name.clone())) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; + } - // 7. 
Events implied by name and topic - if let Some(name) = &body.name { - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(name.clone())) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; - } + if let Some(topic) = &body.topic { + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomTopic, + content: to_raw_value(&RoomTopicEventContent { + topic: topic.clone(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &room_id, + &state_lock, + ) + .await?; + } - if let Some(topic) = &body.topic { - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { - topic: topic.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &room_id, - &state_lock, - ) - .await?; - } + // 8. Events implied by invite (and TODO: invite_3pid) + drop(state_lock); + for user_id in &body.invite { + let _ = invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await; + } - // 8. 
Events implied by invite (and TODO: invite_3pid) - drop(state_lock); - for user_id in &body.invite { - let _ = invite_helper(sender_user, user_id, &room_id, None, body.is_direct).await; - } + // Homeserver specific stuff + if let Some(alias) = alias { + services().rooms.alias.set_alias(&alias, &room_id)?; + } - // Homeserver specific stuff - if let Some(alias) = alias { - services().rooms.alias.set_alias(&alias, &room_id)?; - } + if body.visibility == room::Visibility::Public { + services().rooms.directory.set_public(&room_id)?; + } - if body.visibility == room::Visibility::Public { - services().rooms.directory.set_public(&room_id)?; - } + info!("{} created a room", sender_user); - info!("{} created a room", sender_user); - - Ok(create_room::v3::Response::new(room_id)) + Ok(create_room::v3::Response::new(room_id)) } /// # `GET /_matrix/client/r0/rooms/{roomId}/event/{eventId}` /// /// Gets a single event. /// -/// - You have to currently be joined to the room (TODO: Respect history visibility) -pub async fn get_room_event_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// - You have to currently be joined to the room (TODO: Respect history +/// visibility) +pub async fn get_room_event_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services() - .rooms - .timeline - .get_pdu(&body.event_id)? - .ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; - if !services().rooms.state_accessor.user_can_see_event( - sender_user, - &event.room_id, - &body.event_id, - )? 
{ - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this event.", - )); - } + if !services().rooms.state_accessor.user_can_see_event(sender_user, &event.room_id, &body.event_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this event.", + )); + } - let mut event = (*event).clone(); - event.add_age()?; + let mut event = (*event).clone(); + event.add_age()?; - Ok(get_room_event::v3::Response { - event: event.to_room_event(), - }) + Ok(get_room_event::v3::Response { + event: event.to_room_event(), + }) } /// # `GET /_matrix/client/r0/rooms/{roomId}/aliases` /// /// Lists all aliases of the room. /// -/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable -pub async fn get_room_aliases_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// - Only users joined to the room are allowed to call this TODO: Allow any +/// user to call it if history_visibility is world readable +pub async fn get_room_aliases_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_cache - .is_joined(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? 
{ + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } - Ok(aliases::v3::Response { - aliases: services() - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .filter_map(std::result::Result::ok) - .collect(), - }) + Ok(aliases::v3::Response { + aliases: services() + .rooms + .alias + .local_aliases_for_room(&body.room_id) + .filter_map(std::result::Result::ok) + .collect(), + }) } /// # `POST /_matrix/client/r0/rooms/{roomId}/upgrade` @@ -686,297 +595,256 @@ pub async fn get_room_aliases_route( /// - Transfers some state events /// - Moves local aliases /// - Modifies old room power levels to prevent users from speaking -pub async fn upgrade_room_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn upgrade_room_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .globals - .supported_room_versions() - .contains(&body.new_version) - { - return Err(Error::BadRequest( - ErrorKind::UnsupportedRoomVersion, - "This server does not support that room version.", - )); - } + if !services().globals.supported_room_versions().contains(&body.new_version) { + return Err(Error::BadRequest( + ErrorKind::UnsupportedRoomVersion, + "This server does not support that room version.", + )); + } - // Create a replacement room - let replacement_room = RoomId::new(services().globals.server_name()); - services() - .rooms - .short - .get_or_create_shortroomid(&replacement_room)?; + // Create a replacement room + let replacement_room = RoomId::new(services().globals.server_name()); + services().rooms.short.get_or_create_shortroomid(&replacement_room)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + 
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - // Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further - // Fail if the sender does not have the required permissions - let tombstone_event_id = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTombstone, - content: to_raw_value(&RoomTombstoneEventContent { - body: "This room has been replaced".to_owned(), - replacement_room: replacement_room.clone(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + // Send a m.room.tombstone event to the old room to indicate that it is not + // intended to be used any further Fail if the sender does not have the required + // permissions + let tombstone_event_id = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomTombstone, + content: to_raw_value(&RoomTombstoneEventContent { + body: "This room has been replaced".to_owned(), + replacement_room: replacement_room.clone(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - // Change lock to replacement room - drop(state_lock); - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(replacement_room.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + // Change lock to replacement room + drop(state_lock); + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(replacement_room.clone()).or_default()); + let state_lock = mutex_state.lock().await; - // Get the old room 
creation event - let mut create_event_content = serde_json::from_str::( - services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + // Get the old room creation event + let mut create_event_content = serde_json::from_str::( + services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomCreate, "")? + .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? + .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid room event in database."))?; - // Use the m.room.tombstone event as the predecessor - let predecessor = Some(ruma::events::room::create::PreviousRoom::new( - body.room_id.clone(), - (*tombstone_event_id).to_owned(), - )); + // Use the m.room.tombstone event as the predecessor + let predecessor = Some(ruma::events::room::create::PreviousRoom::new( + body.room_id.clone(), + (*tombstone_event_id).to_owned(), + )); - // Send a m.room.create event containing a predecessor field and the applicable room_version - match body.new_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - create_event_content.insert( - "creator".into(), - json!(&sender_user).try_into().map_err(|e| { - info!("Error forming creation event: {e}"); - Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") - })?, - ); - } - RoomVersionId::V11 => { - // "creator" key no longer exists in V11 rooms - create_event_content.remove("creator"); - } - _ => { - warn!( - "Unexpected or unsupported room version {}", - body.new_version - ); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported 
room version found", - )); - } - } + // Send a m.room.create event containing a predecessor field and the applicable + // room_version + match body.new_version { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => { + create_event_content.insert( + "creator".into(), + json!(&sender_user).try_into().map_err(|e| { + info!("Error forming creation event: {e}"); + Error::BadRequest(ErrorKind::BadJson, "Error forming creation event") + })?, + ); + }, + RoomVersionId::V11 => { + // "creator" key no longer exists in V11 rooms + create_event_content.remove("creator"); + }, + _ => { + warn!("Unexpected or unsupported room version {}", body.new_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + } - create_event_content.insert( - "room_version".into(), - json!(&body.new_version) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); - create_event_content.insert( - "predecessor".into(), - json!(predecessor) - .try_into() - .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, - ); + create_event_content.insert( + "room_version".into(), + json!(&body.new_version) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); + create_event_content.insert( + "predecessor".into(), + json!(predecessor) + .try_into() + .map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"))?, + ); - // Validate creation event content - let de_result = serde_json::from_str::( - to_raw_value(&create_event_content) - .expect("Error forming creation event") - .get(), - ); + // Validate creation event content + let de_result = serde_json::from_str::( + to_raw_value(&create_event_content).expect("Error forming 
creation event").get(), + ); - if de_result.is_err() { - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Error forming creation event", - )); - } + if de_result.is_err() { + return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")); + } - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&create_event_content) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCreate, + content: to_raw_value(&create_event_content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; - // Join the new room - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: services().users.displayname(sender_user)?, - avatar_url: services().users.avatar_url(sender_user)?, - is_direct: None, - third_party_invite: None, - blurhash: services().users.blurhash(sender_user)?, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(sender_user.to_string()), - redacts: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; + // Join the new room + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: services().users.displayname(sender_user)?, + avatar_url: 
services().users.avatar_url(sender_user)?, + is_direct: None, + third_party_invite: None, + blurhash: services().users.blurhash(sender_user)?, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(sender_user.to_string()), + redacts: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; - // Recommended transferable state events list from the specs - let transferable_state_events = vec![ - StateEventType::RoomServerAcl, - StateEventType::RoomEncryption, - StateEventType::RoomName, - StateEventType::RoomAvatar, - StateEventType::RoomTopic, - StateEventType::RoomGuestAccess, - StateEventType::RoomHistoryVisibility, - StateEventType::RoomJoinRules, - StateEventType::RoomPowerLevels, - ]; + // Recommended transferable state events list from the specs + let transferable_state_events = vec![ + StateEventType::RoomServerAcl, + StateEventType::RoomEncryption, + StateEventType::RoomName, + StateEventType::RoomAvatar, + StateEventType::RoomTopic, + StateEventType::RoomGuestAccess, + StateEventType::RoomHistoryVisibility, + StateEventType::RoomJoinRules, + StateEventType::RoomPowerLevels, + ]; - // Replicate transferable state events to the new room - for event_type in transferable_state_events { - let event_content = - match services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &event_type, "")? - { - Some(v) => v.content.clone(), - None => continue, // Skipping missing events. - }; + // Replicate transferable state events to the new room + for event_type in transferable_state_events { + let event_content = match services().rooms.state_accessor.room_state_get(&body.room_id, &event_type, "")? { + Some(v) => v.content.clone(), + None => continue, // Skipping missing events. 
+ }; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: event_type.to_string().into(), - content: event_content, - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &replacement_room, - &state_lock, - ) - .await?; - } + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: event_type.to_string().into(), + content: event_content, + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &replacement_room, + &state_lock, + ) + .await?; + } - // Moves any local aliases to the new room - for alias in services() - .rooms - .alias - .local_aliases_for_room(&body.room_id) - .filter_map(std::result::Result::ok) - { - services() - .rooms - .alias - .set_alias(&alias, &replacement_room)?; - } + // Moves any local aliases to the new room + for alias in services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(std::result::Result::ok) { + services().rooms.alias.set_alias(&alias, &replacement_room)?; + } - // Get the old room power levels - let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( - services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? - .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? - .content - .get(), - ) - .map_err(|_| Error::bad_database("Invalid room event in database."))?; + // Get the old room power levels + let mut power_levels_event_content: RoomPowerLevelsEventContent = serde_json::from_str( + services() + .rooms + .state_accessor + .room_state_get(&body.room_id, &StateEventType::RoomPowerLevels, "")? + .ok_or_else(|| Error::bad_database("Found room without m.room.create event."))? 
+ .content + .get(), + ) + .map_err(|_| Error::bad_database("Invalid room event in database."))?; - // Setting events_default and invite to the greater of 50 and users_default + 1 - let new_level = max(int!(50), power_levels_event_content.users_default + int!(1)); - power_levels_event_content.events_default = new_level; - power_levels_event_content.invite = new_level; + // Setting events_default and invite to the greater of 50 and users_default + 1 + let new_level = max(int!(50), power_levels_event_content.users_default + int!(1)); + power_levels_event_content.events_default = new_level; + power_levels_event_content.invite = new_level; - // Modify the power levels in the old room to prevent sending of events and inviting new users - let _ = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&power_levels_event_content) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - sender_user, - &body.room_id, - &state_lock, - ) - .await?; + // Modify the power levels in the old room to prevent sending of events and + // inviting new users + let _ = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&power_levels_event_content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + sender_user, + &body.room_id, + &state_lock, + ) + .await?; - drop(state_lock); + drop(state_lock); - // Return the replacement room id - Ok(upgrade_room::v3::Response { replacement_room }) + // Return the replacement room id + Ok(upgrade_room::v3::Response { + replacement_room, + }) } diff --git a/src/api/client_server/search.rs b/src/api/client_server/search.rs index 43e4c9f4..67960cc8 100644 --- a/src/api/client_server/search.rs +++ b/src/api/client_server/search.rs @@ 
-1,138 +1,120 @@ -use crate::{services, Error, Result, Ruma}; +use std::collections::BTreeMap; + use ruma::api::client::{ - error::ErrorKind, - search::search_events::{ - self, - v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, - }, + error::ErrorKind, + search::search_events::{ + self, + v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, + }, }; -use std::collections::BTreeMap; +use crate::{services, Error, Result, Ruma}; /// # `POST /_matrix/client/r0/search` /// /// Searches rooms for messages. /// -/// - Only works if the user is currently joined to the room (TODO: Respect history visibility) -pub async fn search_events_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// - Only works if the user is currently joined to the room (TODO: Respect +/// history visibility) +pub async fn search_events_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let search_criteria = body.search_categories.room_events.as_ref().unwrap(); - let filter = &search_criteria.filter; + let search_criteria = body.search_categories.room_events.as_ref().unwrap(); + let filter = &search_criteria.filter; - let room_ids = filter.rooms.clone().unwrap_or_else(|| { - services() - .rooms - .state_cache - .rooms_joined(sender_user) - .filter_map(std::result::Result::ok) - .collect() - }); + let room_ids = filter.rooms.clone().unwrap_or_else(|| { + services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok).collect() + }); - // Use limit or else 10, with maximum 100 - let limit = filter.limit.map_or(10, u64::from).min(100) as usize; + // Use limit or else 10, with maximum 100 + let limit = filter.limit.map_or(10, u64::from).min(100) as usize; - let mut searches = Vec::new(); + let mut searches = Vec::new(); - for room_id in room_ids { - if !services() - .rooms - .state_cache - 
.is_joined(sender_user, &room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view this room.", - )); - } + for room_id in room_ids { + if !services().rooms.state_cache.is_joined(sender_user, &room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view this room.", + )); + } - if let Some(search) = services() - .rooms - .search - .search_pdus(&room_id, &search_criteria.search_term)? - { - searches.push(search.0.peekable()); - } - } + if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? { + searches.push(search.0.peekable()); + } + } - let skip = match body.next_batch.as_ref().map(|s| s.parse()) { - Some(Ok(s)) => s, - Some(Err(_)) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid next_batch token.", - )) - } - None => 0, // Default to the start - }; + let skip = match body.next_batch.as_ref().map(|s| s.parse()) { + Some(Ok(s)) => s, + Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), + None => 0, // Default to the start + }; - let mut results = Vec::new(); - for _ in 0..skip + limit { - if let Some(s) = searches - .iter_mut() - .map(|s| (s.peek().cloned(), s)) - .max_by_key(|(peek, _)| peek.clone()) - .and_then(|(_, i)| i.next()) - { - results.push(s); - } - } + let mut results = Vec::new(); + for _ in 0..skip + limit { + if let Some(s) = searches + .iter_mut() + .map(|s| (s.peek().cloned(), s)) + .max_by_key(|(peek, _)| peek.clone()) + .and_then(|(_, i)| i.next()) + { + results.push(s); + } + } - let results: Vec<_> = results - .iter() - .filter_map(|result| { - services() - .rooms - .timeline - .get_pdu_from_id(result) - .ok()? 
- .filter(|pdu| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .map(|pdu| pdu.to_room_event()) - }) - .map(|result| { - Ok::<_, Error>(SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), - }) - }) - .filter_map(std::result::Result::ok) - .skip(skip) - .take(limit) - .collect(); + let results: Vec<_> = results + .iter() + .filter_map(|result| { + services() + .rooms + .timeline + .get_pdu_from_id(result) + .ok()? + .filter(|pdu| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .unwrap_or(false) + }) + .map(|pdu| pdu.to_room_event()) + }) + .map(|result| { + Ok::<_, Error>(SearchResult { + context: EventContextResult { + end: None, + events_after: Vec::new(), + events_before: Vec::new(), + profile_info: BTreeMap::new(), + start: None, + }, + rank: None, + result: Some(result), + }) + }) + .filter_map(std::result::Result::ok) + .skip(skip) + .take(limit) + .collect(); - let next_batch = if results.len() < limit { - None - } else { - Some((skip + limit).to_string()) - }; + let next_batch = if results.len() < limit { + None + } else { + Some((skip + limit).to_string()) + }; - Ok(search_events::v3::Response::new(ResultCategories { - room_events: ResultRoomEvents { - count: Some((results.len() as u32).into()), // TODO: set this to none. Element shouldn't depend on it - groups: BTreeMap::new(), // TODO - next_batch, - results, - state: BTreeMap::new(), // TODO - highlights: search_criteria - .search_term - .split_terminator(|c: char| !c.is_alphanumeric()) - .map(str::to_lowercase) - .collect(), - }, - })) + Ok(search_events::v3::Response::new(ResultCategories { + room_events: ResultRoomEvents { + count: Some((results.len() as u32).into()), // TODO: set this to none. 
Element shouldn't depend on it + groups: BTreeMap::new(), // TODO + next_batch, + results, + state: BTreeMap::new(), // TODO + highlights: search_criteria + .search_term + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + .collect(), + }, + })) } diff --git a/src/api/client_server/session.rs b/src/api/client_server/session.rs index 56250a0c..dc7ee096 100644 --- a/src/api/client_server/session.rs +++ b/src/api/client_server/session.rs @@ -1,246 +1,221 @@ -use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; -use crate::{services, utils, Error, Result, Ruma}; use argon2::{PasswordHash, PasswordVerifier}; use ruma::{ - api::client::{ - error::ErrorKind, - session::{ - get_login_types::{ - self, - v3::{ApplicationServiceLoginType, PasswordLoginType}, - }, - login::{ - self, - v3::{DiscoveryInfo, HomeserverInfo}, - }, - logout, logout_all, - }, - uiaa::UserIdentifier, - }, - UserId, + api::client::{ + error::ErrorKind, + session::{ + get_login_types::{ + self, + v3::{ApplicationServiceLoginType, PasswordLoginType}, + }, + login::{ + self, + v3::{DiscoveryInfo, HomeserverInfo}, + }, + logout, logout_all, + }, + uiaa::UserIdentifier, + }, + UserId, }; use serde::Deserialize; use tracing::{debug, error, info, warn}; +use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH}; +use crate::{services, utils, Error, Result, Ruma}; + #[derive(Debug, Deserialize)] struct Claims { - sub: String, - //exp: usize, + sub: String, + //exp: usize, } /// # `GET /_matrix/client/v3/login` /// -/// Get the supported login types of this server. One of these should be used as the `type` field -/// when logging in. -pub async fn get_login_types_route( - _body: Ruma, -) -> Result { - Ok(get_login_types::v3::Response::new(vec![ - get_login_types::v3::LoginType::Password(PasswordLoginType::default()), - get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()), - ])) +/// Get the supported login types of this server. 
One of these should be used as +/// the `type` field when logging in. +pub async fn get_login_types_route(_body: Ruma) -> Result { + Ok(get_login_types::v3::Response::new(vec![ + get_login_types::v3::LoginType::Password(PasswordLoginType::default()), + get_login_types::v3::LoginType::ApplicationService(ApplicationServiceLoginType::default()), + ])) } /// # `POST /_matrix/client/v3/login` /// -/// Authenticates the user and returns an access token it can use in subsequent requests. +/// Authenticates the user and returns an access token it can use in subsequent +/// requests. /// -/// - The user needs to authenticate using their password (or if enabled using a json web token) +/// - The user needs to authenticate using their password (or if enabled using a +/// json web token) /// - If `device_id` is known: invalidates old access token of that device /// - If `device_id` is unknown: creates a new device /// - Returns access token that is associated with the user and device /// -/// Note: You can use [`GET /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see +/// Note: You can use [`GET +/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see /// supported login types. pub async fn login_route(body: Ruma) -> Result { - // Validate login method - // TODO: Other login methods - let user_id = match &body.login_info { - #[allow(deprecated)] - login::v3::LoginInfo::Password(login::v3::Password { - identifier, - password, - user, - .. - }) => { - debug!("Got password login type"); - let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - debug!("Using username from identifier field"); - user_id.to_lowercase() - } else if let Some(user_id) = user { - warn!("User \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". 
conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id); - user_id.to_lowercase() - } else { - warn!("Bad login type: {:?}", &body.login_info); - return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); - }; + // Validate login method + // TODO: Other login methods + let user_id = match &body.login_info { + #[allow(deprecated)] + login::v3::LoginInfo::Password(login::v3::Password { + identifier, + password, + user, + .. + }) => { + debug!("Got password login type"); + let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { + debug!("Using username from identifier field"); + user_id.to_lowercase() + } else if let Some(user_id) = user { + warn!( + "User \"{}\" is attempting to login with the deprecated \"user\" field at \ + \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \ + destined to be removed in a future Matrix release.", + user_id + ); + user_id.to_lowercase() + } else { + warn!("Bad login type: {:?}", &body.login_info); + return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); + }; - let user_id = - UserId::parse_with_server_name(username, services().globals.server_name()) - .map_err(|e| { - warn!("Failed to parse username from user logging in: {}", e); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - })?; + let user_id = UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| { + warn!("Failed to parse username from user logging in: {}", e); + Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + })?; - let hash = services() - .users - .password_hash(&user_id)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "Wrong username or password.", - ))?; + let hash = services() + .users + .password_hash(&user_id)? 
+ .ok_or(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."))?; - if hash.is_empty() { - return Err(Error::BadRequest( - ErrorKind::UserDeactivated, - "The user has been deactivated", - )); - } + if hash.is_empty() { + return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated")); + } - let Ok(parsed_hash) = PasswordHash::new(&hash) else { - error!("error while hashing user {}", user_id); - return Err(Error::BadServerResponse("could not hash")); - }; + let Ok(parsed_hash) = PasswordHash::new(&hash) else { + error!("error while hashing user {}", user_id); + return Err(Error::BadServerResponse("could not hash")); + }; - let hash_matches = services() - .globals - .argon - .verify_password(password.as_bytes(), &parsed_hash) - .is_ok(); + let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok(); - if !hash_matches { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Wrong username or password.", - )); - } + if !hash_matches { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password.")); + } - user_id - } - login::v3::LoginInfo::Token(login::v3::Token { token }) => { - debug!("Got token login type"); - if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { - let token = jsonwebtoken::decode::( - token, - jwt_decoding_key, - &jsonwebtoken::Validation::default(), - ) - .map_err(|e| { - warn!("Failed to parse JWT token from user logging in: {}", e); - Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") - })?; + user_id + }, + login::v3::LoginInfo::Token(login::v3::Token { + token, + }) => { + debug!("Got token login type"); + if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() { + let token = + jsonwebtoken::decode::(token, jwt_decoding_key, &jsonwebtoken::Validation::default()) + .map_err(|e| { + warn!("Failed to parse JWT token from user logging in: {}", e); + 
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.") + })?; - let username = token.claims.sub.to_lowercase(); + let username = token.claims.sub.to_lowercase(); - UserId::parse_with_server_name(username, services().globals.server_name()).map_err( - |e| { - warn!("Failed to parse username from user logging in: {}", e); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - }, - )? - } else { - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Token login is not supported (server has no jwt decoding key).", - )); - } - } - #[allow(deprecated)] - login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { - identifier, - user, - }) => { - debug!("Got appservice login type"); - if !body.from_appservice { - info!("User tried logging in as an appservice, but request body is not from a known/registered appservice"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Forbidden login type.", - )); - }; - let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { - user_id.to_lowercase() - } else if let Some(user_id) = user { - warn!("Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id); - user_id.to_lowercase() - } else { - return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); - }; + UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| { + warn!("Failed to parse username from user logging in: {}", e); + Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + })? 
+ } else { + return Err(Error::BadRequest( + ErrorKind::Unknown, + "Token login is not supported (server has no jwt decoding key).", + )); + } + }, + #[allow(deprecated)] + login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService { + identifier, + user, + }) => { + debug!("Got appservice login type"); + if !body.from_appservice { + info!( + "User tried logging in as an appservice, but request body is not from a known/registered \ + appservice" + ); + return Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden login type.")); + }; + let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier { + user_id.to_lowercase() + } else if let Some(user_id) = user { + warn!( + "Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \ + \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \ + destined to be removed in a future Matrix release.", + user_id + ); + user_id.to_lowercase() + } else { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type.")); + }; - UserId::parse_with_server_name(username, services().globals.server_name()).map_err( - |e| { - warn!("Failed to parse username from appservice logging in: {}", e); - Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") - }, - )? - } - _ => { - warn!("Unsupported or unknown login type: {:?}", &body.login_info); - debug!("JSON body: {:?}", &body.json_body); - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Unsupported or unknown login type.", - )); - } - }; + UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| { + warn!("Failed to parse username from appservice logging in: {}", e); + Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.") + })? 
+ }, + _ => { + warn!("Unsupported or unknown login type: {:?}", &body.login_info); + debug!("JSON body: {:?}", &body.json_body); + return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported or unknown login type.")); + }, + }; - // Generate new device id if the user didn't specify one - let device_id = body - .device_id - .clone() - .unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); + // Generate new device id if the user didn't specify one + let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into()); - // Generate a new token for the device - let token = utils::random_string(TOKEN_LENGTH); + // Generate a new token for the device + let token = utils::random_string(TOKEN_LENGTH); - // Determine if device_id was provided and exists in the db for this user - let device_exists = body.device_id.as_ref().map_or(false, |device_id| { - services() - .users - .all_device_ids(&user_id) - .any(|x| x.as_ref().map_or(false, |v| v == device_id)) - }); + // Determine if device_id was provided and exists in the db for this user + let device_exists = body.device_id.as_ref().map_or(false, |device_id| { + services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id)) + }); - if device_exists { - services().users.set_token(&user_id, &device_id, &token)?; - } else { - services().users.create_device( - &user_id, - &device_id, - &token, - body.initial_device_display_name.clone(), - )?; - } + if device_exists { + services().users.set_token(&user_id, &device_id, &token)?; + } else { + services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?; + } - // send client well-known if specified so the client knows to reconfigure itself - let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new( - services() - .globals - .well_known_client() - .to_owned() - .unwrap_or_else(|| "".to_owned()), - )); + // send client well-known if specified so the 
client knows to reconfigure itself + let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new( + services().globals.well_known_client().to_owned().unwrap_or_else(|| "".to_owned()), + )); - info!("{} logged in", user_id); + info!("{} logged in", user_id); - // home_server is deprecated but apparently must still be sent despite it being deprecated over 6 years ago. - // initially i thought this macro was unnecessary, but ruma uses this same macro for the same reason so... - #[allow(deprecated)] - Ok(login::v3::Response { - user_id, - access_token: token, - device_id, - well_known: { - if client_discovery_info.homeserver.base_url.as_str() == "" { - None - } else { - Some(client_discovery_info) - } - }, - expires_in: None, - home_server: Some(services().globals.server_name().to_owned()), - refresh_token: None, - }) + // home_server is deprecated but apparently must still be sent despite it being + // deprecated over 6 years ago. initially i thought this macro was unnecessary, + // but ruma uses this same macro for the same reason so... 
+ #[allow(deprecated)] + Ok(login::v3::Response { + user_id, + access_token: token, + device_id, + well_known: { + if client_discovery_info.homeserver.base_url.as_str() == "" { + None + } else { + Some(client_discovery_info) + } + }, + expires_in: None, + home_server: Some(services().globals.server_name().to_owned()), + refresh_token: None, + }) } /// # `POST /_matrix/client/v3/logout` @@ -248,19 +223,20 @@ pub async fn login_route(body: Ruma) -> Result) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_ref().expect("user is authenticated"); - services().users.remove_device(sender_user, sender_device)?; + services().users.remove_device(sender_user, sender_device)?; - // send device list update for user after logout - services().users.mark_device_key_update(sender_user)?; + // send device list update for user after logout + services().users.mark_device_key_update(sender_user)?; - Ok(logout::v3::Response::new()) + Ok(logout::v3::Response::new()) } /// # `POST /_matrix/client/r0/logout/all` @@ -268,23 +244,23 @@ pub async fn logout_route(body: Ruma) -> Result, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// Note: This is equivalent to calling [`GET +/// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this +/// user. 
+pub async fn logout_all_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - for device_id in services().users.all_device_ids(sender_user).flatten() { - services().users.remove_device(sender_user, &device_id)?; - } + for device_id in services().users.all_device_ids(sender_user).flatten() { + services().users.remove_device(sender_user, &device_id)?; + } - // send device list update for user after logout - services().users.mark_device_key_update(sender_user)?; + // send device list update for user after logout + services().users.mark_device_key_update(sender_user)?; - Ok(logout_all::v3::Response::new()) + Ok(logout_all::v3::Response::new()) } diff --git a/src/api/client_server/space.rs b/src/api/client_server/space.rs index e2ea8c34..ac6139a5 100644 --- a/src/api/client_server/space.rs +++ b/src/api/client_server/space.rs @@ -1,34 +1,19 @@ -use crate::{services, Result, Ruma}; use ruma::api::client::space::get_hierarchy; +use crate::{services, Result, Ruma}; + /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy`` /// -/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space. -pub async fn get_hierarchy_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +/// Paginates over the space tree in a depth-first manner to locate child rooms +/// of a given space. 
+pub async fn get_hierarchy_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let skip = body - .from - .as_ref() - .and_then(|s| s.parse::().ok()) - .unwrap_or(0); + let skip = body.from.as_ref().and_then(|s| s.parse::().ok()).unwrap_or(0); - let limit = body.limit.map_or(10, u64::from).min(100) as usize; + let limit = body.limit.map_or(10, u64::from).min(100) as usize; - let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself + let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself - services() - .rooms - .spaces - .get_hierarchy( - sender_user, - &body.room_id, - limit, - skip, - max_depth, - body.suggested_only, - ) - .await + services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await } diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index 1d2f5e9b..6dfb0fcc 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -1,42 +1,44 @@ use std::sync::Arc; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse}; use ruma::{ - api::client::{ - error::ErrorKind, - state::{get_state_events, get_state_events_for_key, send_state_event}, - }, - events::{ - room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType, - }, - serde::Raw, - EventId, RoomId, UserId, + api::client::{ + error::ErrorKind, + state::{get_state_events, get_state_events_for_key, send_state_event}, + }, + events::{room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType}, + serde::Raw, + EventId, RoomId, UserId, }; use tracing::{error, log::warn}; +use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse}; + /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}` /// /// Sends a state 
event into the room. /// /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_key_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event_id = send_state_event_for_key_helper( - sender_user, - &body.room_id, - &body.event_type, - &body.body.body, // Yes, I hate it too - body.state_key.clone(), - ) - .await?; + let event_id = send_state_event_for_key_helper( + sender_user, + &body.room_id, + &body.event_type, + &body.body.body, // Yes, I hate it too + body.state_key.clone(), + ) + .await?; - let event_id = (*event_id).to_owned(); - Ok(send_state_event::v3::Response { event_id }) + let event_id = (*event_id).to_owned(); + Ok(send_state_event::v3::Response { + event_id, + }) } /// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}` @@ -44,249 +46,209 @@ pub async fn send_state_event_for_key_route( /// Sends a state event into the room. 
/// /// - The only requirement for the content is that it has to be valid json -/// - Tries to send the event into the room, auth rules will determine if it is allowed +/// - Tries to send the event into the room, auth rules will determine if it is +/// allowed /// - If event is new canonical_alias: Rejects if alias is incorrect pub async fn send_state_event_for_empty_key_route( - body: Ruma, + body: Ruma, ) -> Result> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - // Forbid m.room.encryption if encryption is disabled - if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Encryption has been disabled", - )); - } + // Forbid m.room.encryption if encryption is disabled + if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled")); + } - let event_id = send_state_event_for_key_helper( - sender_user, - &body.room_id, - &body.event_type.to_string().into(), - &body.body.body, - body.state_key.clone(), - ) - .await?; + let event_id = send_state_event_for_key_helper( + sender_user, + &body.room_id, + &body.event_type.to_string().into(), + &body.body.body, + body.state_key.clone(), + ) + .await?; - let event_id = (*event_id).to_owned(); - Ok(send_state_event::v3::Response { event_id }.into()) + let event_id = (*event_id).to_owned(); + Ok(send_state_event::v3::Response { + event_id, + } + .into()) } /// # `GET /_matrix/client/r0/rooms/{roomid}/state` /// /// Get all state events for a room. 
/// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub async fn get_state_events_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view the room state.", - )); - } + if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view the room state.", + )); + } - Ok(get_state_events::v3::Response { - room_state: services() - .rooms - .state_accessor - .room_state_full(&body.room_id) - .await? - .values() - .map(|pdu| pdu.to_state_event()) - .collect(), - }) + Ok(get_state_events::v3::Response { + room_state: services() + .rooms + .state_accessor + .room_state_full(&body.room_id) + .await? + .values() + .map(|pdu| pdu.to_state_event()) + .collect(), + }) } /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}` /// /// Get single state event of a room with the specified state key. 
-/// The optional query parameter `?format=event|content` allows returning the full room state event -/// or just the state event's content (default behaviour) +/// The optional query parameter `?format=event|content` allows returning the +/// full room state event or just the state event's content (default behaviour) /// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub async fn get_state_events_for_key_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view the room state.", - )); - } + if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view the room state.", + )); + } - let event = services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &body.event_type, &body.state_key)? 
- .ok_or_else(|| { - warn!( - "State event {:?} not found in room {:?}", - &body.event_type, &body.room_id - ); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") - })?; - if body - .format - .as_ref() - .is_some_and(|f| f.to_lowercase().eq("event")) - { - Ok(get_state_events_for_key::v3::Response { - content: None, - event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { - error!("Invalid room state event in database: {}", e); - Error::bad_database("Invalid room state event in database") - })?, - }) - } else { - Ok(get_state_events_for_key::v3::Response { - content: Some(serde_json::from_str(event.content.get()).map_err(|e| { - error!("Invalid room state event content in database: {}", e); - Error::bad_database("Invalid room state event content in database") - })?), - event: None, - }) - } + let event = + services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else( + || { + warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); + Error::BadRequest(ErrorKind::NotFound, "State event not found.") + }, + )?; + if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { + Ok(get_state_events_for_key::v3::Response { + content: None, + event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { + error!("Invalid room state event in database: {}", e); + Error::bad_database("Invalid room state event in database") + })?, + }) + } else { + Ok(get_state_events_for_key::v3::Response { + content: Some(serde_json::from_str(event.content.get()).map_err(|e| { + error!("Invalid room state event content in database: {}", e); + Error::bad_database("Invalid room state event content in database") + })?), + event: None, + }) + } } /// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}` /// /// Get single state event of a room. 
-/// The optional query parameter `?format=event|content` allows returning the full room state event -/// or just the state event's content (default behaviour) +/// The optional query parameter `?format=event|content` allows returning the +/// full room state event or just the state event's content (default behaviour) /// -/// - If not joined: Only works if current room history visibility is world readable +/// - If not joined: Only works if current room history visibility is world +/// readable pub async fn get_state_events_for_empty_key_route( - body: Ruma, + body: Ruma, ) -> Result> { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_accessor - .user_can_see_state_events(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You don't have permission to view the room state.", - )); - } + if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You don't have permission to view the room state.", + )); + } - let event = services() - .rooms - .state_accessor - .room_state_get(&body.room_id, &body.event_type, "")? 
- .ok_or_else(|| { - warn!( - "State event {:?} not found in room {:?}", - &body.event_type, &body.room_id - ); - Error::BadRequest(ErrorKind::NotFound, "State event not found.") - })?; + let event = + services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| { + warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id); + Error::BadRequest(ErrorKind::NotFound, "State event not found.") + })?; - if body - .format - .as_ref() - .is_some_and(|f| f.to_lowercase().eq("event")) - { - Ok(get_state_events_for_key::v3::Response { - content: None, - event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { - error!("Invalid room state event in database: {}", e); - Error::bad_database("Invalid room state event in database") - })?, - } - .into()) - } else { - Ok(get_state_events_for_key::v3::Response { - content: Some(serde_json::from_str(event.content.get()).map_err(|e| { - error!("Invalid room state event content in database: {}", e); - Error::bad_database("Invalid room state event content in database") - })?), - event: None, - } - .into()) - } + if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) { + Ok(get_state_events_for_key::v3::Response { + content: None, + event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| { + error!("Invalid room state event in database: {}", e); + Error::bad_database("Invalid room state event in database") + })?, + } + .into()) + } else { + Ok(get_state_events_for_key::v3::Response { + content: Some(serde_json::from_str(event.content.get()).map_err(|e| { + error!("Invalid room state event content in database: {}", e); + Error::bad_database("Invalid room state event content in database") + })?), + event: None, + } + .into()) + } } async fn send_state_event_for_key_helper( - sender: &UserId, - room_id: &RoomId, - event_type: &StateEventType, - json: &Raw, - state_key: String, + sender: &UserId, room_id: &RoomId, 
event_type: &StateEventType, json: &Raw, state_key: String, ) -> Result> { - let sender_user = sender; + let sender_user = sender; - // TODO: Review this check, error if event is unparsable, use event type, allow alias if it - // previously existed - if let Ok(canonical_alias) = - serde_json::from_str::(json.json().get()) - { - let mut aliases = canonical_alias.alt_aliases.clone(); + // TODO: Review this check, error if event is unparsable, use event type, allow + // alias if it previously existed + if let Ok(canonical_alias) = serde_json::from_str::(json.json().get()) { + let mut aliases = canonical_alias.alt_aliases.clone(); - if let Some(alias) = canonical_alias.alias { - aliases.push(alias); - } + if let Some(alias) = canonical_alias.alias { + aliases.push(alias); + } - for alias in aliases { - if alias.server_name() != services().globals.server_name() - || services() - .rooms - .alias - .resolve_local_alias(&alias)? - .filter(|room| room == room_id) // Make sure it's the right room - .is_none() - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You are only allowed to send canonical_alias \ - events when it's aliases already exists", - )); - } - } - } + for alias in aliases { + if alias.server_name() != services().globals.server_name() + || services() + .rooms + .alias + .resolve_local_alias(&alias)? 
+ .filter(|room| room == room_id) // Make sure it's the right room + .is_none() + { + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "You are only allowed to send canonical_alias events when it's aliases already exists", + )); + } + } + } - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + let state_lock = mutex_state.lock().await; - let event_id = services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: event_type.to_string().into(), - content: serde_json::from_str(json.json().get()).expect("content is valid json"), - unsigned: None, - state_key: Some(state_key), - redacts: None, - }, - sender_user, - room_id, - &state_lock, - ) - .await?; + let event_id = services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: event_type.to_string().into(), + content: serde_json::from_str(json.json().get()).expect("content is valid json"), + unsigned: None, + state_key: Some(state_key), + redacts: None, + }, + sender_user, + room_id, + &state_lock, + ) + .await?; - Ok(event_id) + Ok(event_id) } diff --git a/src/api/client_server/sync.rs b/src/api/client_server/sync.rs index 5ec6c26c..28c2793d 100644 --- a/src/api/client_server/sync.rs +++ b/src/api/client_server/sync.rs @@ -1,42 +1,43 @@ -use crate::{ - service::rooms::timeline::PduCount, services, Error, PduEvent, Result, Ruma, RumaResponse, -}; -use ruma::{ - api::client::{ - filter::{FilterDefinition, LazyLoadOptions}, - sync::sync_events::{ - self, - v3::{ - Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, - LeftRoom, Presence, RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, - }, - v4::SlidingOp, - DeviceLists, UnreadNotificationsCount, - }, - 
uiaa::UiaaResponse, - }, - events::{ - presence::PresenceEvent, - room::member::{MembershipState, RoomMemberEventContent}, - StateEventType, TimelineEventType, - }, - serde::Raw, - uint, DeviceId, OwnedDeviceId, OwnedUserId, RoomId, UInt, UserId, -}; use std::{ - cmp::Ordering, - collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, - sync::Arc, - time::Duration, + cmp::Ordering, + collections::{hash_map::Entry, BTreeMap, BTreeSet, HashMap, HashSet}, + sync::Arc, + time::Duration, +}; + +use ruma::{ + api::client::{ + filter::{FilterDefinition, LazyLoadOptions}, + sync::sync_events::{ + self, + v3::{ + Ephemeral, Filter, GlobalAccountData, InviteState, InvitedRoom, JoinedRoom, LeftRoom, Presence, + RoomAccountData, RoomSummary, Rooms, State, Timeline, ToDevice, + }, + v4::SlidingOp, + DeviceLists, UnreadNotificationsCount, + }, + uiaa::UiaaResponse, + }, + events::{ + presence::PresenceEvent, + room::member::{MembershipState, RoomMemberEventContent}, + StateEventType, TimelineEventType, + }, + serde::Raw, + uint, DeviceId, OwnedDeviceId, OwnedUserId, RoomId, UInt, UserId, }; use tokio::sync::watch::Sender; use tracing::error; +use crate::{service::rooms::timeline::PduCount, services, Error, PduEvent, Result, Ruma, RumaResponse}; + /// # `GET /_matrix/client/r0/sync` /// /// Synchronize the client's state with the latest state on the server. /// -/// - This endpoint takes a `since` parameter which should be the `next_batch` value from a +/// - This endpoint takes a `since` parameter which should be the `next_batch` +/// value from a /// previous request for incremental syncs. 
/// /// Calling this endpoint without a `since` parameter returns: @@ -45,1730 +46,1397 @@ use tracing::error; /// - Joined and invited member counts, heroes /// - All state events /// -/// Calling this endpoint with a `since` parameter from a previous `next_batch` returns: -/// For joined rooms: +/// Calling this endpoint with a `since` parameter from a previous `next_batch` +/// returns: For joined rooms: /// - Some of the most recent events of each timeline that happened after since -/// - If user joined the room after since: All state events (unless lazy loading is activated) and +/// - If user joined the room after since: All state events (unless lazy loading +/// is activated) and /// all device list updates in that room -/// - If the user was already in the room: A list of all events that are in the state now, but were +/// - If the user was already in the room: A list of all events that are in the +/// state now, but were /// not in the state at `since` -/// - If the state we send contains a member event: Joined and invited member counts, heroes +/// - If the state we send contains a member event: Joined and invited member +/// counts, heroes /// - Device list updates that happened after `since` -/// - If there are events in the timeline we send or the user send updated his read mark: Notification counts +/// - If there are events in the timeline we send or the user send updated his +/// read mark: Notification counts /// - EDUs that are active now (read receipts, typing updates, presence) /// - TODO: Allow multiple sync streams to support Pantalaimon /// /// For invited rooms: -/// - If the user was invited after `since`: A subset of the state of the room at the point of the invite +/// - If the user was invited after `since`: A subset of the state of the room +/// at the point of the invite /// /// For left rooms: -/// - If the user left after `since`: prev_batch token, empty state (TODO: subset of the state at the point of the leave) +/// - If the user 
left after `since`: prev_batch token, empty state (TODO: +/// subset of the state at the point of the leave) /// -/// - Sync is handled in an async task, multiple requests from the same device with the same +/// - Sync is handled in an async task, multiple requests from the same device +/// with the same /// `since` will be cached pub async fn sync_events_route( - body: Ruma, + body: Ruma, ) -> Result> { - let sender_user = body.sender_user.expect("user is authenticated"); - let sender_device = body.sender_device.expect("user is authenticated"); - let body = body.body; + let sender_user = body.sender_user.expect("user is authenticated"); + let sender_device = body.sender_device.expect("user is authenticated"); + let body = body.body; - let mut rx = match services() - .globals - .sync_receivers - .write() - .unwrap() - .entry((sender_user.clone(), sender_device.clone())) - { - Entry::Vacant(v) => { - let (tx, rx) = tokio::sync::watch::channel(None); + let mut rx = + match services().globals.sync_receivers.write().unwrap().entry((sender_user.clone(), sender_device.clone())) { + Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::watch::channel(None); - v.insert((body.since.clone(), rx.clone())); + v.insert((body.since.clone(), rx.clone())); - tokio::spawn(sync_helper_wrapper( - sender_user.clone(), - sender_device.clone(), - body, - tx, - )); + tokio::spawn(sync_helper_wrapper(sender_user.clone(), sender_device.clone(), body, tx)); - rx - } - Entry::Occupied(mut o) => { - if o.get().0 != body.since { - let (tx, rx) = tokio::sync::watch::channel(None); + rx + }, + Entry::Occupied(mut o) => { + if o.get().0 != body.since { + let (tx, rx) = tokio::sync::watch::channel(None); - o.insert((body.since.clone(), rx.clone())); + o.insert((body.since.clone(), rx.clone())); - tokio::spawn(sync_helper_wrapper( - sender_user.clone(), - sender_device.clone(), - body, - tx, - )); + tokio::spawn(sync_helper_wrapper(sender_user.clone(), sender_device.clone(), body, tx)); - rx - } else 
{ - o.get().1.clone() - } - } - }; + rx + } else { + o.get().1.clone() + } + }, + }; - let we_have_to_wait = rx.borrow().is_none(); - if we_have_to_wait { - if let Err(e) = rx.changed().await { - error!("Error waiting for sync: {}", e); - } - } + let we_have_to_wait = rx.borrow().is_none(); + if we_have_to_wait { + if let Err(e) = rx.changed().await { + error!("Error waiting for sync: {}", e); + } + } - let result = match rx - .borrow() - .as_ref() - .expect("When sync channel changes it's always set to some") - { - Ok(response) => Ok(response.clone()), - Err(error) => Err(error.to_response()), - }; + let result = match rx.borrow().as_ref().expect("When sync channel changes it's always set to some") { + Ok(response) => Ok(response.clone()), + Err(error) => Err(error.to_response()), + }; - result + result } async fn sync_helper_wrapper( - sender_user: OwnedUserId, - sender_device: OwnedDeviceId, - body: sync_events::v3::Request, - tx: Sender>>, + sender_user: OwnedUserId, sender_device: OwnedDeviceId, body: sync_events::v3::Request, + tx: Sender>>, ) { - let since = body.since.clone(); + let since = body.since.clone(); - let r = sync_helper(sender_user.clone(), sender_device.clone(), body).await; + let r = sync_helper(sender_user.clone(), sender_device.clone(), body).await; - if let Ok((_, caching_allowed)) = r { - if !caching_allowed { - match services() - .globals - .sync_receivers - .write() - .unwrap() - .entry((sender_user, sender_device)) - { - Entry::Occupied(o) => { - // Only remove if the device didn't start a different /sync already - if o.get().0 == since { - o.remove(); - } - } - Entry::Vacant(_) => {} - } - } - } + if let Ok((_, caching_allowed)) = r { + if !caching_allowed { + match services().globals.sync_receivers.write().unwrap().entry((sender_user, sender_device)) { + Entry::Occupied(o) => { + // Only remove if the device didn't start a different /sync already + if o.get().0 == since { + o.remove(); + } + }, + Entry::Vacant(_) => {}, + } + } + } - 
let _ = tx.send(Some(r.map(|(r, _)| r))); + let _ = tx.send(Some(r.map(|(r, _)| r))); } async fn sync_helper( - sender_user: OwnedUserId, - sender_device: OwnedDeviceId, - body: sync_events::v3::Request, - // bool = caching allowed + sender_user: OwnedUserId, + sender_device: OwnedDeviceId, + body: sync_events::v3::Request, + // bool = caching allowed ) -> Result<(sync_events::v3::Response, bool), Error> { - // Presence update - if services().globals.allow_local_presence() { - services() - .rooms - .edus - .presence - .ping_presence(&sender_user, body.set_presence)?; - } + // Presence update + if services().globals.allow_local_presence() { + services().rooms.edus.presence.ping_presence(&sender_user, body.set_presence)?; + } - // Setup watchers, so if there's no response, we can wait for them - let watcher = services().globals.watch(&sender_user, &sender_device); + // Setup watchers, so if there's no response, we can wait for them + let watcher = services().globals.watch(&sender_user, &sender_device); - let next_batch = services().globals.current_count()?; - let next_batchcount = PduCount::Normal(next_batch); - let next_batch_string = next_batch.to_string(); + let next_batch = services().globals.current_count()?; + let next_batchcount = PduCount::Normal(next_batch); + let next_batch_string = next_batch.to_string(); - // Load filter - let filter = match body.filter { - None => FilterDefinition::default(), - Some(Filter::FilterDefinition(filter)) => filter, - Some(Filter::FilterId(filter_id)) => services() - .users - .get_filter(&sender_user, &filter_id)? 
- .unwrap_or_default(), - }; + // Load filter + let filter = match body.filter { + None => FilterDefinition::default(), + Some(Filter::FilterDefinition(filter)) => filter, + Some(Filter::FilterId(filter_id)) => services().users.get_filter(&sender_user, &filter_id)?.unwrap_or_default(), + }; - let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { - LazyLoadOptions::Enabled { - include_redundant_members: redundant, - } => (true, redundant), - LazyLoadOptions::Disabled => (false, false), - }; + let (lazy_load_enabled, lazy_load_send_redundant) = match filter.room.state.lazy_load_options { + LazyLoadOptions::Enabled { + include_redundant_members: redundant, + } => (true, redundant), + LazyLoadOptions::Disabled => (false, false), + }; - let full_state = body.full_state; + let full_state = body.full_state; - let mut joined_rooms = BTreeMap::new(); - let since = body - .since - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); - let sincecount = PduCount::Normal(since); + let mut joined_rooms = BTreeMap::new(); + let since = body.since.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); + let sincecount = PduCount::Normal(since); - let mut presence_updates = HashMap::new(); - let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in - let mut device_list_updates = HashSet::new(); - let mut device_list_left = HashSet::new(); + let mut presence_updates = HashMap::new(); + let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in + let mut device_list_updates = HashSet::new(); + let mut device_list_left = HashSet::new(); - // Look for device list updates of this account - device_list_updates.extend( - services() - .users - .keys_changed(sender_user.as_ref(), since, None) - .filter_map(std::result::Result::ok), - ); + // Look for device list updates of this account + device_list_updates + 
.extend(services().users.keys_changed(sender_user.as_ref(), since, None).filter_map(std::result::Result::ok)); - let all_joined_rooms = services() - .rooms - .state_cache - .rooms_joined(&sender_user) - .collect::>(); - for room_id in all_joined_rooms { - let room_id = room_id?; - if let Ok(joined_room) = load_joined_room( - &sender_user, - &sender_device, - &room_id, - since, - sincecount, - next_batch, - next_batchcount, - lazy_load_enabled, - lazy_load_send_redundant, - full_state, - &mut device_list_updates, - &mut left_encrypted_users, - ) - .await - { - if !joined_room.is_empty() { - joined_rooms.insert(room_id.clone(), joined_room); - } + let all_joined_rooms = services().rooms.state_cache.rooms_joined(&sender_user).collect::>(); + for room_id in all_joined_rooms { + let room_id = room_id?; + if let Ok(joined_room) = load_joined_room( + &sender_user, + &sender_device, + &room_id, + since, + sincecount, + next_batch, + next_batchcount, + lazy_load_enabled, + lazy_load_send_redundant, + full_state, + &mut device_list_updates, + &mut left_encrypted_users, + ) + .await + { + if !joined_room.is_empty() { + joined_rooms.insert(room_id.clone(), joined_room); + } - if services().globals.allow_local_presence() { - process_room_presence_updates(&mut presence_updates, &room_id, since).await?; - } - } - } + if services().globals.allow_local_presence() { + process_room_presence_updates(&mut presence_updates, &room_id, since).await?; + } + } + } - let mut left_rooms = BTreeMap::new(); - let all_left_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_left(&sender_user) - .collect(); - for result in all_left_rooms { - let (room_id, _) = result?; + let mut left_rooms = BTreeMap::new(); + let all_left_rooms: Vec<_> = services().rooms.state_cache.rooms_left(&sender_user).collect(); + for result in all_left_rooms { + let (room_id, _) = result?; - let mut left_state_events = Vec::new(); + let mut left_state_events = Vec::new(); - { - // Get and drop the lock to wait for 
remaining operations to finish - let mutex_insert = Arc::clone( - services() - .globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().await; - drop(insert_lock); - }; + { + // Get and drop the lock to wait for remaining operations to finish + let mutex_insert = + Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + let insert_lock = mutex_insert.lock().await; + drop(insert_lock); + }; - let left_count = services() - .rooms - .state_cache - .get_left_count(&room_id, &sender_user)?; + let left_count = services().rooms.state_cache.get_left_count(&room_id, &sender_user)?; - // Left before last sync - if Some(since) >= left_count { - continue; - } + // Left before last sync + if Some(since) >= left_count { + continue; + } - if !services().rooms.metadata.exists(&room_id)? { - // This is just a rejected invite, not a room we know - continue; - } + if !services().rooms.metadata.exists(&room_id)? { + // This is just a rejected invite, not a room we know + continue; + } - let since_shortstatehash = services() - .rooms - .user - .get_token_shortstatehash(&room_id, since)?; + let since_shortstatehash = services().rooms.user.get_token_shortstatehash(&room_id, since)?; - let since_state_ids = match since_shortstatehash { - Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, - None => HashMap::new(), - }; + let since_state_ids = match since_shortstatehash { + Some(s) => services().rooms.state_accessor.state_full_ids(s).await?, + None => HashMap::new(), + }; - let left_event_id = match services().rooms.state_accessor.room_state_get_id( - &room_id, - &StateEventType::RoomMember, - sender_user.as_str(), - )? 
{ - Some(e) => e, - None => { - error!("Left room but no left state event"); - continue; - } - }; + let left_event_id = match services().rooms.state_accessor.room_state_get_id( + &room_id, + &StateEventType::RoomMember, + sender_user.as_str(), + )? { + Some(e) => e, + None => { + error!("Left room but no left state event"); + continue; + }, + }; - let left_shortstatehash = match services() - .rooms - .state_accessor - .pdu_shortstatehash(&left_event_id)? - { - Some(s) => s, - None => { - error!("Leave event has no state"); - continue; - } - }; + let left_shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(&left_event_id)? { + Some(s) => s, + None => { + error!("Leave event has no state"); + continue; + }, + }; - let mut left_state_ids = services() - .rooms - .state_accessor - .state_full_ids(left_shortstatehash) - .await?; + let mut left_state_ids = services().rooms.state_accessor.state_full_ids(left_shortstatehash).await?; - let leave_shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; + let leave_shortstatekey = + services().rooms.short.get_or_create_shortstatekey(&StateEventType::RoomMember, sender_user.as_str())?; - left_state_ids.insert(leave_shortstatekey, left_event_id); + left_state_ids.insert(leave_shortstatekey, left_event_id); - let mut i = 0; - for (key, id) in left_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = - services().rooms.short.get_statekey_from_short(key)?; + let mut i = 0; + for (key, id) in left_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; - if !lazy_load_enabled + if !lazy_load_enabled || event_type != StateEventType::RoomMember || full_state // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || *sender_user == 
state_key - { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; + { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; - left_state_events.push(pdu.to_sync_state_event()); + left_state_events.push(pdu.to_sync_state_event()); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + } + } - left_rooms.insert( - room_id.clone(), - LeftRoom { - account_data: RoomAccountData { events: Vec::new() }, - timeline: Timeline { - limited: false, - prev_batch: Some(next_batch_string.clone()), - events: Vec::new(), - }, - state: State { - events: left_state_events, - }, - }, - ); - } + left_rooms.insert( + room_id.clone(), + LeftRoom { + account_data: RoomAccountData { + events: Vec::new(), + }, + timeline: Timeline { + limited: false, + prev_batch: Some(next_batch_string.clone()), + events: Vec::new(), + }, + state: State { + events: left_state_events, + }, + }, + ); + } - let mut invited_rooms = BTreeMap::new(); - let all_invited_rooms: Vec<_> = services() - .rooms - .state_cache - .rooms_invited(&sender_user) - .collect(); - for result in all_invited_rooms { - let (room_id, invite_state_events) = result?; + let mut invited_rooms = BTreeMap::new(); + let all_invited_rooms: Vec<_> = services().rooms.state_cache.rooms_invited(&sender_user).collect(); + for result in all_invited_rooms { + let (room_id, invite_state_events) = result?; - { - // Get and drop the lock to wait for remaining operations to finish - let mutex_insert = Arc::clone( - services() - .globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().await; - drop(insert_lock); - }; + { + // Get and drop the lock to wait for 
remaining operations to finish + let mutex_insert = + Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + let insert_lock = mutex_insert.lock().await; + drop(insert_lock); + }; - let invite_count = services() - .rooms - .state_cache - .get_invite_count(&room_id, &sender_user)?; + let invite_count = services().rooms.state_cache.get_invite_count(&room_id, &sender_user)?; - // Invited before last sync - if Some(since) >= invite_count { - continue; - } + // Invited before last sync + if Some(since) >= invite_count { + continue; + } - invited_rooms.insert( - room_id.clone(), - InvitedRoom { - invite_state: InviteState { - events: invite_state_events, - }, - }, - ); - } + invited_rooms.insert( + room_id.clone(), + InvitedRoom { + invite_state: InviteState { + events: invite_state_events, + }, + }, + ); + } - for user_id in left_encrypted_users { - let dont_share_encrypted_room = services() - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(std::result::Result::ok) - .filter_map(|other_room_id| { - Some( - services() - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need to tell - // them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } + for user_id in left_encrypted_users { + let dont_share_encrypted_room = services() + .rooms + .user + .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? + .filter_map(std::result::Result::ok) + .filter_map(|other_room_id| { + Some( + services() + .rooms + .state_accessor + .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .ok()? 
+ .is_some(), + ) + }) + .all(|encrypted| !encrypted); + // If the user doesn't share an encrypted room with the target anymore, we need + // to tell them + if dont_share_encrypted_room { + device_list_left.insert(user_id); + } + } - // Remove all to-device events the device received *last time* - services() - .users - .remove_to_device_events(&sender_user, &sender_device, since)?; + // Remove all to-device events the device received *last time* + services().users.remove_to_device_events(&sender_user, &sender_device, since)?; - let response = sync_events::v3::Response { - next_batch: next_batch_string, - rooms: Rooms { - leave: left_rooms, - join: joined_rooms, - invite: invited_rooms, - knock: BTreeMap::new(), // TODO - }, - presence: Presence { - events: presence_updates - .into_values() - .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) - .collect(), - }, - account_data: GlobalAccountData { - events: services() - .account_data - .changes_since(None, &sender_user, since)? 
- .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| Error::bad_database("Invalid account event in database.")) - .ok() - }) - .collect(), - }, - device_lists: DeviceLists { - changed: device_list_updates.into_iter().collect(), - left: device_list_left.into_iter().collect(), - }, - device_one_time_keys_count: services() - .users - .count_one_time_keys(&sender_user, &sender_device)?, - to_device: ToDevice { - events: services() - .users - .get_to_device_events(&sender_user, &sender_device)?, - }, - // Fallback keys are not yet supported - device_unused_fallback_key_types: None, - }; + let response = sync_events::v3::Response { + next_batch: next_batch_string, + rooms: Rooms { + leave: left_rooms, + join: joined_rooms, + invite: invited_rooms, + knock: BTreeMap::new(), // TODO + }, + presence: Presence { + events: presence_updates + .into_values() + .map(|v| Raw::new(&v).expect("PresenceEvent always serializes successfully")) + .collect(), + }, + account_data: GlobalAccountData { + events: services() + .account_data + .changes_since(None, &sender_user, since)? 
+ .into_iter() + .filter_map(|(_, v)| { + serde_json::from_str(v.json().get()) + .map_err(|_| Error::bad_database("Invalid account event in database.")) + .ok() + }) + .collect(), + }, + device_lists: DeviceLists { + changed: device_list_updates.into_iter().collect(), + left: device_list_left.into_iter().collect(), + }, + device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, + to_device: ToDevice { + events: services().users.get_to_device_events(&sender_user, &sender_device)?, + }, + // Fallback keys are not yet supported + device_unused_fallback_key_types: None, + }; - // TODO: Retry the endpoint instead of returning (waiting for #118) - if !full_state - && response.rooms.is_empty() - && response.presence.is_empty() - && response.account_data.is_empty() - && response.device_lists.is_empty() - && response.to_device.is_empty() - { - // Hang a few seconds so requests are not spammed - // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or_default(); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - let _ = tokio::time::timeout(duration, watcher).await; - Ok((response, false)) - } else { - Ok((response, since != next_batch)) // Only cache if we made progress - } + // TODO: Retry the endpoint instead of returning (waiting for #118) + if !full_state + && response.rooms.is_empty() + && response.presence.is_empty() + && response.account_data.is_empty() + && response.device_lists.is_empty() + && response.to_device.is_empty() + { + // Hang a few seconds so requests are not spammed + // Stop hanging if new info arrives + let mut duration = body.timeout.unwrap_or_default(); + if duration.as_secs() > 30 { + duration = Duration::from_secs(30); + } + let _ = tokio::time::timeout(duration, watcher).await; + Ok((response, false)) + } else { + Ok((response, since != next_batch)) // Only cache if we made progress + } } async fn process_room_presence_updates( - presence_updates: &mut HashMap, 
- room_id: &RoomId, - since: u64, + presence_updates: &mut HashMap, room_id: &RoomId, since: u64, ) -> Result<()> { - // Take presence updates from this room - for (user_id, _, presence_event) in services() - .rooms - .edus - .presence - .presence_since(room_id, since) - { - match presence_updates.entry(user_id) { - Entry::Vacant(slot) => { - slot.insert(presence_event); - } - Entry::Occupied(mut slot) => { - let curr_event = slot.get_mut(); - let curr_content = &mut curr_event.content; - let new_content = presence_event.content; + // Take presence updates from this room + for (user_id, _, presence_event) in services().rooms.edus.presence.presence_since(room_id, since) { + match presence_updates.entry(user_id) { + Entry::Vacant(slot) => { + slot.insert(presence_event); + }, + Entry::Occupied(mut slot) => { + let curr_event = slot.get_mut(); + let curr_content = &mut curr_event.content; + let new_content = presence_event.content; - // Update existing presence event with more info - curr_content.presence = new_content.presence; - curr_content.status_msg = new_content - .status_msg - .or_else(|| curr_content.status_msg.take()); - curr_content.last_active_ago = - new_content.last_active_ago.or(curr_content.last_active_ago); - curr_content.displayname = new_content - .displayname - .or_else(|| curr_content.displayname.take()); - curr_content.avatar_url = new_content - .avatar_url - .or_else(|| curr_content.avatar_url.take()); - curr_content.currently_active = new_content - .currently_active - .or(curr_content.currently_active); - } - } - } + // Update existing presence event with more info + curr_content.presence = new_content.presence; + curr_content.status_msg = new_content.status_msg.or_else(|| curr_content.status_msg.take()); + curr_content.last_active_ago = new_content.last_active_ago.or(curr_content.last_active_ago); + curr_content.displayname = new_content.displayname.or_else(|| curr_content.displayname.take()); + curr_content.avatar_url = 
new_content.avatar_url.or_else(|| curr_content.avatar_url.take()); + curr_content.currently_active = new_content.currently_active.or(curr_content.currently_active); + }, + } + } - Ok(()) + Ok(()) } #[allow(clippy::too_many_arguments)] async fn load_joined_room( - sender_user: &UserId, - sender_device: &DeviceId, - room_id: &RoomId, - since: u64, - sincecount: PduCount, - next_batch: u64, - next_batchcount: PduCount, - lazy_load_enabled: bool, - lazy_load_send_redundant: bool, - full_state: bool, - device_list_updates: &mut HashSet, - left_encrypted_users: &mut HashSet, + sender_user: &UserId, sender_device: &DeviceId, room_id: &RoomId, since: u64, sincecount: PduCount, + next_batch: u64, next_batchcount: PduCount, lazy_load_enabled: bool, lazy_load_send_redundant: bool, + full_state: bool, device_list_updates: &mut HashSet, left_encrypted_users: &mut HashSet, ) -> Result { - { - // Get and drop the lock to wait for remaining operations to finish - // This will make sure the we have all events until next_batch - let mutex_insert = Arc::clone( - services() - .globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().await; - drop(insert_lock); - }; + { + // Get and drop the lock to wait for remaining operations to finish + // This will make sure the we have all events until next_batch + let mutex_insert = + Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.to_owned()).or_default()); + let insert_lock = mutex_insert.lock().await; + drop(insert_lock); + }; - let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; + let (timeline_pdus, limited) = load_timeline(sender_user, room_id, sincecount, 10)?; - let send_notification_counts = !timeline_pdus.is_empty() - || services() - .rooms - .user - .last_notification_read(sender_user, room_id)? 
- > since; + let send_notification_counts = + !timeline_pdus.is_empty() || services().rooms.user.last_notification_read(sender_user, room_id)? > since; - let mut timeline_users = HashSet::new(); - for (_, event) in &timeline_pdus { - timeline_users.insert(event.sender.as_str().to_owned()); - } + let mut timeline_users = HashSet::new(); + for (_, event) in &timeline_pdus { + timeline_users.insert(event.sender.as_str().to_owned()); + } - services().rooms.lazy_loading.lazy_load_confirm_delivery( - sender_user, - sender_device, - room_id, - sincecount, - )?; + services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, room_id, sincecount)?; - // Database queries: + // Database queries: - let current_shortstatehash = - if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? { - s - } else { - error!("Room {} has no state", room_id); - return Err(Error::BadDatabase("Room has no state")); - }; + let current_shortstatehash = if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? { + s + } else { + error!("Room {} has no state", room_id); + return Err(Error::BadDatabase("Room has no state")); + }; - let since_shortstatehash = services() - .rooms - .user - .get_token_shortstatehash(room_id, since)?; + let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, since)?; - let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = - if timeline_pdus.is_empty() && since_shortstatehash == Some(current_shortstatehash) { - // No state changes - (Vec::new(), None, None, false, Vec::new()) - } else { - // Calculates joined_member_count, invited_member_count and heroes - let calculate_counts = || { - let joined_member_count = services() - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0); - let invited_member_count = services() - .rooms - .state_cache - .room_invited_count(room_id)? 
- .unwrap_or(0); + let (heroes, joined_member_count, invited_member_count, joined_since_last_sync, state_events) = if timeline_pdus + .is_empty() + && since_shortstatehash == Some(current_shortstatehash) + { + // No state changes + (Vec::new(), None, None, false, Vec::new()) + } else { + // Calculates joined_member_count, invited_member_count and heroes + let calculate_counts = || { + let joined_member_count = services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0); + let invited_member_count = services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0); - // Recalculate heroes (first 5 members) - let mut heroes = Vec::new(); + // Recalculate heroes (first 5 members) + let mut heroes = Vec::new(); - if joined_member_count + invited_member_count <= 5 { - // Go through all PDUs and for each member event, check if the user is still joined or - // invited until we have 5 or we reach the end + if joined_member_count + invited_member_count <= 5 { + // Go through all PDUs and for each member event, check if the user is still + // joined or invited until we have 5 or we reach the end - for hero in services() - .rooms - .timeline - .all_pdus(sender_user, room_id)? - .filter_map(std::result::Result::ok) // Ignore all broken pdus - .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) - .map(|(_, pdu)| { - let content: RoomMemberEventContent = - serde_json::from_str(pdu.content.get()).map_err(|_| { - Error::bad_database("Invalid member event in database.") - })?; + for hero in services() + .rooms + .timeline + .all_pdus(sender_user, room_id)? 
+ .filter_map(std::result::Result::ok) // Ignore all broken pdus + .filter(|(_, pdu)| pdu.kind == TimelineEventType::RoomMember) + .map(|(_, pdu)| { + let content: RoomMemberEventContent = serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid member event in database."))?; - if let Some(state_key) = &pdu.state_key { - let user_id = UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - // The membership was and still is invite or join - if matches!( - content.membership, - MembershipState::Join | MembershipState::Invite - ) && (services() - .rooms - .state_cache - .is_joined(&user_id, room_id)? - || services() - .rooms - .state_cache - .is_invited(&user_id, room_id)?) - { - Ok::<_, Error>(Some(state_key.clone())) - } else { - Ok(None) - } - } else { - Ok(None) - } - }) - // Filter out buggy users - .filter_map(std::result::Result::ok) - // Filter for possible heroes - .flatten() - { - if heroes.contains(&hero) || hero == sender_user.as_str() { - continue; - } + // The membership was and still is invite or join + if matches!(content.membership, MembershipState::Join | MembershipState::Invite) + && (services().rooms.state_cache.is_joined(&user_id, room_id)? + || services().rooms.state_cache.is_invited(&user_id, room_id)?) 
+ { + Ok::<_, Error>(Some(state_key.clone())) + } else { + Ok(None) + } + } else { + Ok(None) + } + }) + // Filter out buggy users + .filter_map(std::result::Result::ok) + // Filter for possible heroes + .flatten() + { + if heroes.contains(&hero) || hero == sender_user.as_str() { + continue; + } - heroes.push(hero); - } - } + heroes.push(hero); + } + } - Ok::<_, Error>(( - Some(joined_member_count), - Some(invited_member_count), - heroes, - )) - }; + Ok::<_, Error>((Some(joined_member_count), Some(invited_member_count), heroes)) + }; - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services() - .rooms - .state_accessor - .state_get( - shortstatehash, - &StateEventType::RoomMember, - sender_user.as_str(), - ) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = since_shortstatehash + .and_then(|shortstatehash| { + services() + .rooms + .state_accessor + .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .transpose() + }) + .transpose()? 
+ .and_then(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); - let joined_since_last_sync = since_sender_member - .map_or(true, |member| member.membership != MembershipState::Join); + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - if since_shortstatehash.is_none() || joined_since_last_sync { - // Probably since = 0, we will do an initial sync + if since_shortstatehash.is_none() || joined_since_last_sync { + // Probably since = 0, we will do an initial sync - let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; + let (joined_member_count, invited_member_count, heroes) = calculate_counts()?; - let current_state_ids = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; + let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); - let mut i = 0; - for (shortstatekey, id) in current_state_ids { - let (event_type, state_key) = services() - .rooms - .short - .get_statekey_from_short(shortstatekey)?; + let mut i = 0; + for (shortstatekey, id) in current_state_ids { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?; - if event_type != StateEventType::RoomMember { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - state_events.push(pdu); + if event_type != StateEventType::RoomMember { + let pdu = match services().rooms.timeline.get_pdu(&id)? 
{ + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + state_events.push(pdu); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } else if !lazy_load_enabled || full_state || timeline_users.contains(&state_key) // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 || *sender_user == state_key - { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; + { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; - // This check is in case a bad user ID made it into the database - if let Ok(uid) = UserId::parse(&state_key) { - lazy_loaded.insert(uid); - } - state_events.push(pdu); + // This check is in case a bad user ID made it into the database + if let Ok(uid) = UserId::parse(&state_key) { + lazy_loaded.insert(uid); + } + state_events.push(pdu); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + } - // Reset lazy loading because this is an initial sync - services().rooms.lazy_loading.lazy_load_reset( - sender_user, - sender_device, - room_id, - )?; + // Reset lazy loading because this is an initial sync + services().rooms.lazy_loading.lazy_load_reset(sender_user, sender_device, room_id)?; - // The state_events above should contain all timeline_users, let's mark them as lazy - // loaded. - services().rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ); + // The state_events above should contain all timeline_users, let's mark them as + // lazy loaded. 
+ services().rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - ( - heroes, - joined_member_count, - invited_member_count, - true, - state_events, - ) - } else { - // Incremental /sync - let since_shortstatehash = since_shortstatehash.unwrap(); + (heroes, joined_member_count, invited_member_count, true, state_events) + } else { + // Incremental /sync + let since_shortstatehash = since_shortstatehash.unwrap(); - let mut state_events = Vec::new(); - let mut lazy_loaded = HashSet::new(); + let mut state_events = Vec::new(); + let mut lazy_loaded = HashSet::new(); - if since_shortstatehash != current_shortstatehash { - let current_state_ids = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services() - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; + if since_shortstatehash != current_shortstatehash { + let current_state_ids = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; + let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?; - for (key, id) in current_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let pdu = match services().rooms.timeline.get_pdu(&id)? { - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; + for (key, id) in current_state_ids { + if full_state || since_state_ids.get(&key) != Some(&id) { + let pdu = match services().rooms.timeline.get_pdu(&id)? 
{ + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; - if pdu.kind == TimelineEventType::RoomMember { - match UserId::parse( - pdu.state_key - .as_ref() - .expect("State event has state key") - .clone(), - ) { - Ok(state_key_userid) => { - lazy_loaded.insert(state_key_userid); - } - Err(e) => error!("Invalid state key for member event: {}", e), - } - } + if pdu.kind == TimelineEventType::RoomMember { + match UserId::parse(pdu.state_key.as_ref().expect("State event has state key").clone()) { + Ok(state_key_userid) => { + lazy_loaded.insert(state_key_userid); + }, + Err(e) => error!("Invalid state key for member event: {}", e), + } + } - state_events.push(pdu); - tokio::task::yield_now().await; - } - } - } + state_events.push(pdu); + tokio::task::yield_now().await; + } + } + } - for (_, event) in &timeline_pdus { - if lazy_loaded.contains(&event.sender) { - continue; - } + for (_, event) in &timeline_pdus { + if lazy_loaded.contains(&event.sender) { + continue; + } - if !services().rooms.lazy_loading.lazy_load_was_sent_before( - sender_user, - sender_device, - room_id, - &event.sender, - )? || lazy_load_send_redundant - { - if let Some(member_event) = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomMember, - event.sender.as_str(), - )? { - lazy_loaded.insert(event.sender.clone()); - state_events.push(member_event); - } - } - } + if !services().rooms.lazy_loading.lazy_load_was_sent_before( + sender_user, + sender_device, + room_id, + &event.sender, + )? || lazy_load_send_redundant + { + if let Some(member_event) = services().rooms.state_accessor.room_state_get( + room_id, + &StateEventType::RoomMember, + event.sender.as_str(), + )? 
{ + lazy_loaded.insert(event.sender.clone()); + state_events.push(member_event); + } + } + } - services().rooms.lazy_loading.lazy_load_mark_sent( - sender_user, - sender_device, - room_id, - lazy_loaded, - next_batchcount, - ); + services().rooms.lazy_loading.lazy_load_mark_sent( + sender_user, + sender_device, + room_id, + lazy_loaded, + next_batchcount, + ); - let encrypted_room = services() - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + let encrypted_room = services() + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? + .is_some(); - let since_encryption = services().rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = + services().rooms.state_accessor.state_get(since_shortstatehash, &StateEventType::RoomEncryption, "")?; - // Calculations: - let new_encrypted_room = encrypted_room && since_encryption.is_none(); + // Calculations: + let new_encrypted_room = encrypted_room && since_encryption.is_none(); - let send_member_count = state_events - .iter() - .any(|event| event.kind == TimelineEventType::RoomMember); + let send_member_count = state_events.iter().any(|event| event.kind == TimelineEventType::RoomMember); - if encrypted_room { - for state_event in &state_events { - if state_event.kind != TimelineEventType::RoomMember { - continue; - } + if encrypted_room { + for state_event in &state_events { + if state_event.kind != TimelineEventType::RoomMember { + continue; + } - if let Some(state_key) = &state_event.state_key { - let user_id = UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; + if let Some(state_key) = &state_event.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - if user_id == sender_user { - continue; - } + if 
user_id == sender_user { + continue; + } - let new_membership = serde_json::from_str::( - state_event.content.get(), - ) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let new_membership = serde_json::from_str::(state_event.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room(sender_user, &user_id, room_id)? { - device_list_updates.insert(user_id); - } - } - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - } - _ => {} - } - } - } - } + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(sender_user, &user_id, room_id)? { + device_list_updates.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } - if joined_since_last_sync && encrypted_room || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_updates.extend( - services() - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(sender_user, user_id, room_id) - .unwrap_or(false) - }), - ); - } + if joined_since_last_sync && encrypted_room || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_updates.extend( + services() + .rooms + .state_cache + .room_members(room_id) + .flatten() + .filter(|user_id| { + // Don't send key updates from the sender to the sender + sender_user != 
user_id + }) + .filter(|user_id| { + // Only send keys if the sender doesn't share an encrypted room with the target + // already + !share_encrypted_room(sender_user, user_id, room_id).unwrap_or(false) + }), + ); + } - let (joined_member_count, invited_member_count, heroes) = if send_member_count { - calculate_counts()? - } else { - (None, None, Vec::new()) - }; + let (joined_member_count, invited_member_count, heroes) = if send_member_count { + calculate_counts()? + } else { + (None, None, Vec::new()) + }; - ( - heroes, - joined_member_count, - invited_member_count, - joined_since_last_sync, - state_events, - ) - } - }; + ( + heroes, + joined_member_count, + invited_member_count, + joined_since_last_sync, + state_events, + ) + } + }; - // Look for device list updates in this room - device_list_updates.extend( - services() - .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(std::result::Result::ok), - ); + // Look for device list updates in this room + device_list_updates + .extend(services().users.keys_changed(room_id.as_ref(), since, None).filter_map(std::result::Result::ok)); - let notification_count = if send_notification_counts { - Some( - services() - .rooms - .user - .notification_count(sender_user, room_id)? - .try_into() - .expect("notification count can't go that high"), - ) - } else { - None - }; + let notification_count = if send_notification_counts { + Some( + services() + .rooms + .user + .notification_count(sender_user, room_id)? + .try_into() + .expect("notification count can't go that high"), + ) + } else { + None + }; - let highlight_count = if send_notification_counts { - Some( - services() - .rooms - .user - .highlight_count(sender_user, room_id)? - .try_into() - .expect("highlight count can't go that high"), - ) - } else { - None - }; + let highlight_count = if send_notification_counts { + Some( + services() + .rooms + .user + .highlight_count(sender_user, room_id)? 
+ .try_into() + .expect("highlight count can't go that high"), + ) + } else { + None + }; - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - } - PduCount::Normal(c) => c.to_string(), - })) - })?; + let prev_batch = timeline_pdus.first().map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + }, + PduCount::Normal(c) => c.to_string(), + })) + })?; - let room_events: Vec<_> = timeline_pdus - .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); + let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); - let mut edus: Vec<_> = services() - .rooms - .edus - .read_receipt - .readreceipts_since(room_id, since) - .filter_map(std::result::Result::ok) // Filter out buggy events - .map(|(_, _, v)| v) - .collect(); + let mut edus: Vec<_> = services() + .rooms + .edus + .read_receipt + .readreceipts_since(room_id, since) + .filter_map(std::result::Result::ok) // Filter out buggy events + .map(|(_, _, v)| v) + .collect(); - if services().rooms.edus.typing.last_typing_update(room_id)? > since { - edus.push( - serde_json::from_str( - &serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id)?) - .expect("event is valid, we just created it"), - ) - .expect("event is valid, we just created it"), - ); - } + if services().rooms.edus.typing.last_typing_update(room_id)? > since { + edus.push( + serde_json::from_str( + &serde_json::to_string(&services().rooms.edus.typing.typings_all(room_id)?) 
+ .expect("event is valid, we just created it"), + ) + .expect("event is valid, we just created it"), + ); + } - // Save the state after this sync so we can send the correct state diff next sync - services().rooms.user.associate_token_shortstatehash( - room_id, - next_batch, - current_shortstatehash, - )?; + // Save the state after this sync so we can send the correct state diff next + // sync + services().rooms.user.associate_token_shortstatehash(room_id, next_batch, current_shortstatehash)?; - Ok(JoinedRoom { - account_data: RoomAccountData { - events: services() - .account_data - .changes_since(Some(room_id), sender_user, since)? - .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| Error::bad_database("Invalid account event in database.")) - .ok() - }) - .collect(), - }, - summary: RoomSummary { - heroes, - joined_member_count: joined_member_count.map(|n| (n as u32).into()), - invited_member_count: invited_member_count.map(|n| (n as u32).into()), - }, - unread_notifications: UnreadNotificationsCount { - highlight_count, - notification_count, - }, - timeline: Timeline { - limited: limited || joined_since_last_sync, - prev_batch, - events: room_events, - }, - state: State { - events: state_events - .iter() - .map(|pdu| pdu.to_sync_state_event()) - .collect(), - }, - ephemeral: Ephemeral { events: edus }, - unread_thread_notifications: BTreeMap::new(), - }) + Ok(JoinedRoom { + account_data: RoomAccountData { + events: services() + .account_data + .changes_since(Some(room_id), sender_user, since)? 
+ .into_iter() + .filter_map(|(_, v)| { + serde_json::from_str(v.json().get()) + .map_err(|_| Error::bad_database("Invalid account event in database.")) + .ok() + }) + .collect(), + }, + summary: RoomSummary { + heroes, + joined_member_count: joined_member_count.map(|n| (n as u32).into()), + invited_member_count: invited_member_count.map(|n| (n as u32).into()), + }, + unread_notifications: UnreadNotificationsCount { + highlight_count, + notification_count, + }, + timeline: Timeline { + limited: limited || joined_since_last_sync, + prev_batch, + events: room_events, + }, + state: State { + events: state_events.iter().map(|pdu| pdu.to_sync_state_event()).collect(), + }, + ephemeral: Ephemeral { + events: edus, + }, + unread_thread_notifications: BTreeMap::new(), + }) } fn load_timeline( - sender_user: &UserId, - room_id: &RoomId, - roomsincecount: PduCount, - limit: u64, + sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, ) -> Result<(Vec<(PduCount, PduEvent)>, bool), Error> { - let timeline_pdus; - let limited; - if services() - .rooms - .timeline - .last_timeline_count(sender_user, room_id)? - > roomsincecount - { - let mut non_timeline_pdus = services() - .rooms - .timeline - .pdus_until(sender_user, room_id, PduCount::max())? - .filter_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - .take_while(|(pducount, _)| pducount > &roomsincecount); + let timeline_pdus; + let limited; + if services().rooms.timeline.last_timeline_count(sender_user, room_id)? > roomsincecount { + let mut non_timeline_pdus = services() + .rooms + .timeline + .pdus_until(sender_user, room_id, PduCount::max())? 
+ .filter_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) + .take_while(|(pducount, _)| pducount > &roomsincecount); - // Take the last events for the timeline - timeline_pdus = non_timeline_pdus - .by_ref() - .take(limit as usize) - .collect::>() - .into_iter() - .rev() - .collect::>(); + // Take the last events for the timeline + timeline_pdus = + non_timeline_pdus.by_ref().take(limit as usize).collect::>().into_iter().rev().collect::>(); - // They /sync response doesn't always return all messages, so we say the output is - // limited unless there are events in non_timeline_pdus - limited = non_timeline_pdus.next().is_some(); - } else { - timeline_pdus = Vec::new(); - limited = false; - } - Ok((timeline_pdus, limited)) + // They /sync response doesn't always return all messages, so we say the output + // is limited unless there are events in non_timeline_pdus + limited = non_timeline_pdus.next().is_some(); + } else { + timeline_pdus = Vec::new(); + limited = false; + } + Ok((timeline_pdus, limited)) } -fn share_encrypted_room( - sender_user: &UserId, - user_id: &UserId, - ignore_room: &RoomId, -) -> Result { - Ok(services() - .rooms - .user - .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? - .filter_map(std::result::Result::ok) - .filter(|room_id| room_id != ignore_room) - .filter_map(|other_room_id| { - Some( - services() - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .any(|encrypted| encrypted)) +fn share_encrypted_room(sender_user: &UserId, user_id: &UserId, ignore_room: &RoomId) -> Result { + Ok(services() + .rooms + .user + .get_shared_rooms(vec![sender_user.to_owned(), user_id.to_owned()])? 
+ .filter_map(std::result::Result::ok) + .filter(|room_id| room_id != ignore_room) + .filter_map(|other_room_id| { + Some( + services() + .rooms + .state_accessor + .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .ok()? + .is_some(), + ) + }) + .any(|encrypted| encrypted)) } pub async fn sync_events_v4_route( - body: Ruma, + body: Ruma, ) -> Result> { - let sender_user = body.sender_user.expect("user is authenticated"); - let sender_device = body.sender_device.expect("user is authenticated"); - let mut body = body.body; - // Setup watchers, so if there's no response, we can wait for them - let watcher = services().globals.watch(&sender_user, &sender_device); + let sender_user = body.sender_user.expect("user is authenticated"); + let sender_device = body.sender_device.expect("user is authenticated"); + let mut body = body.body; + // Setup watchers, so if there's no response, we can wait for them + let watcher = services().globals.watch(&sender_user, &sender_device); - let next_batch = services().globals.next_count()?; + let next_batch = services().globals.next_count()?; - let globalsince = body - .pos - .as_ref() - .and_then(|string| string.parse().ok()) - .unwrap_or(0); + let globalsince = body.pos.as_ref().and_then(|string| string.parse().ok()).unwrap_or(0); - if globalsince == 0 { - if let Some(conn_id) = &body.conn_id { - services().users.forget_sync_request_connection( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - ); - } - } + if globalsince == 0 { + if let Some(conn_id) = &body.conn_id { + services().users.forget_sync_request_connection( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + ); + } + } - // Get sticky parameters from cache - let known_rooms = services().users.update_sync_request_with_cache( - sender_user.clone(), - sender_device.clone(), - &mut body, - ); + // Get sticky parameters from cache + let known_rooms = + services().users.update_sync_request_with_cache(sender_user.clone(), 
sender_device.clone(), &mut body); - let all_joined_rooms = services() - .rooms - .state_cache - .rooms_joined(&sender_user) - .filter_map(std::result::Result::ok) - .collect::>(); + let all_joined_rooms = + services().rooms.state_cache.rooms_joined(&sender_user).filter_map(std::result::Result::ok).collect::>(); - if body.extensions.to_device.enabled.unwrap_or(false) { - services() - .users - .remove_to_device_events(&sender_user, &sender_device, globalsince)?; - } + if body.extensions.to_device.enabled.unwrap_or(false) { + services().users.remove_to_device_events(&sender_user, &sender_device, globalsince)?; + } - let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in - let mut device_list_changes = HashSet::new(); - let mut device_list_left = HashSet::new(); + let mut left_encrypted_users = HashSet::new(); // Users that have left any encrypted rooms the sender was in + let mut device_list_changes = HashSet::new(); + let mut device_list_left = HashSet::new(); - if body.extensions.e2ee.enabled.unwrap_or(false) { - // Look for device list updates of this account - device_list_changes.extend( - services() - .users - .keys_changed(sender_user.as_ref(), globalsince, None) - .filter_map(std::result::Result::ok), - ); + if body.extensions.e2ee.enabled.unwrap_or(false) { + // Look for device list updates of this account + device_list_changes.extend( + services().users.keys_changed(sender_user.as_ref(), globalsince, None).filter_map(std::result::Result::ok), + ); - for room_id in &all_joined_rooms { - let current_shortstatehash = - if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? { - s - } else { - error!("Room {} has no state", room_id); - continue; - }; + for room_id in &all_joined_rooms { + let current_shortstatehash = if let Some(s) = services().rooms.state.get_room_shortstatehash(room_id)? 
{ + s + } else { + error!("Room {} has no state", room_id); + continue; + }; - let since_shortstatehash = services() - .rooms - .user - .get_token_shortstatehash(room_id, globalsince)?; + let since_shortstatehash = services().rooms.user.get_token_shortstatehash(room_id, globalsince)?; - let since_sender_member: Option = since_shortstatehash - .and_then(|shortstatehash| { - services() - .rooms - .state_accessor - .state_get( - shortstatehash, - &StateEventType::RoomMember, - sender_user.as_str(), - ) - .transpose() - }) - .transpose()? - .and_then(|pdu| { - serde_json::from_str(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid PDU in database.")) - .ok() - }); + let since_sender_member: Option = since_shortstatehash + .and_then(|shortstatehash| { + services() + .rooms + .state_accessor + .state_get(shortstatehash, &StateEventType::RoomMember, sender_user.as_str()) + .transpose() + }) + .transpose()? + .and_then(|pdu| { + serde_json::from_str(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database.")) + .ok() + }); - let encrypted_room = services() - .rooms - .state_accessor - .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? - .is_some(); + let encrypted_room = services() + .rooms + .state_accessor + .state_get(current_shortstatehash, &StateEventType::RoomEncryption, "")? 
+ .is_some(); - if let Some(since_shortstatehash) = since_shortstatehash { - // Skip if there are only timeline changes - if since_shortstatehash == current_shortstatehash { - continue; - } + if let Some(since_shortstatehash) = since_shortstatehash { + // Skip if there are only timeline changes + if since_shortstatehash == current_shortstatehash { + continue; + } - let since_encryption = services().rooms.state_accessor.state_get( - since_shortstatehash, - &StateEventType::RoomEncryption, - "", - )?; + let since_encryption = services().rooms.state_accessor.state_get( + since_shortstatehash, + &StateEventType::RoomEncryption, + "", + )?; - let joined_since_last_sync = since_sender_member - .map_or(true, |member| member.membership != MembershipState::Join); + let joined_since_last_sync = + since_sender_member.map_or(true, |member| member.membership != MembershipState::Join); - let new_encrypted_room = encrypted_room && since_encryption.is_none(); - if encrypted_room { - let current_state_ids = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; - let since_state_ids = services() - .rooms - .state_accessor - .state_full_ids(since_shortstatehash) - .await?; + let new_encrypted_room = encrypted_room && since_encryption.is_none(); + if encrypted_room { + let current_state_ids = + services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; + let since_state_ids = services().rooms.state_accessor.state_full_ids(since_shortstatehash).await?; - for (key, id) in current_state_ids { - if since_state_ids.get(&key) != Some(&id) { - let pdu = match services().rooms.timeline.get_pdu(&id)? 
{ - Some(pdu) => pdu, - None => { - error!("Pdu in state not found: {}", id); - continue; - } - }; - if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key) = &pdu.state_key { - let user_id = - UserId::parse(state_key.clone()).map_err(|_| { - Error::bad_database("Invalid UserId in member PDU.") - })?; + for (key, id) in current_state_ids { + if since_state_ids.get(&key) != Some(&id) { + let pdu = match services().rooms.timeline.get_pdu(&id)? { + Some(pdu) => pdu, + None => { + error!("Pdu in state not found: {}", id); + continue; + }, + }; + if pdu.kind == TimelineEventType::RoomMember { + if let Some(state_key) = &pdu.state_key { + let user_id = UserId::parse(state_key.clone()) + .map_err(|_| Error::bad_database("Invalid UserId in member PDU."))?; - if user_id == sender_user { - continue; - } + if user_id == sender_user { + continue; + } - let new_membership = serde_json::from_str::< - RoomMemberEventContent, - >( - pdu.content.get() - ) - .map_err(|_| Error::bad_database("Invalid PDU in database."))? - .membership; + let new_membership = + serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid PDU in database."))? + .membership; - match new_membership { - MembershipState::Join => { - // A new user joined an encrypted room - if !share_encrypted_room( - &sender_user, - &user_id, - room_id, - )? 
{ - device_list_changes.insert(user_id); - } - } - MembershipState::Leave => { - // Write down users that have left encrypted rooms we are in - left_encrypted_users.insert(user_id); - } - _ => {} - } - } - } - } - } - if joined_since_last_sync || new_encrypted_room { - // If the user is in a new encrypted room, give them all joined users - device_list_changes.extend( - services() - .rooms - .state_cache - .room_members(room_id) - .flatten() - .filter(|user_id| { - // Don't send key updates from the sender to the sender - &sender_user != user_id - }) - .filter(|user_id| { - // Only send keys if the sender doesn't share an encrypted room with the target already - !share_encrypted_room(&sender_user, user_id, room_id) - .unwrap_or(false) - }), - ); - } - } - } - // Look for device list updates in this room - device_list_changes.extend( - services() - .users - .keys_changed(room_id.as_ref(), globalsince, None) - .filter_map(std::result::Result::ok), - ); - } - for user_id in left_encrypted_users { - let dont_share_encrypted_room = services() - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? - .filter_map(std::result::Result::ok) - .filter_map(|other_room_id| { - Some( - services() - .rooms - .state_accessor - .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") - .ok()? - .is_some(), - ) - }) - .all(|encrypted| !encrypted); - // If the user doesn't share an encrypted room with the target anymore, we need to tell - // them - if dont_share_encrypted_room { - device_list_left.insert(user_id); - } - } - } + match new_membership { + MembershipState::Join => { + // A new user joined an encrypted room + if !share_encrypted_room(&sender_user, &user_id, room_id)? 
{ + device_list_changes.insert(user_id); + } + }, + MembershipState::Leave => { + // Write down users that have left encrypted rooms we are in + left_encrypted_users.insert(user_id); + }, + _ => {}, + } + } + } + } + } + if joined_since_last_sync || new_encrypted_room { + // If the user is in a new encrypted room, give them all joined users + device_list_changes.extend( + services() + .rooms + .state_cache + .room_members(room_id) + .flatten() + .filter(|user_id| { + // Don't send key updates from the sender to the sender + &sender_user != user_id + }) + .filter(|user_id| { + // Only send keys if the sender doesn't share an encrypted room with the target + // already + !share_encrypted_room(&sender_user, user_id, room_id).unwrap_or(false) + }), + ); + } + } + } + // Look for device list updates in this room + device_list_changes.extend( + services().users.keys_changed(room_id.as_ref(), globalsince, None).filter_map(std::result::Result::ok), + ); + } + for user_id in left_encrypted_users { + let dont_share_encrypted_room = services() + .rooms + .user + .get_shared_rooms(vec![sender_user.clone(), user_id.clone()])? + .filter_map(std::result::Result::ok) + .filter_map(|other_room_id| { + Some( + services() + .rooms + .state_accessor + .room_state_get(&other_room_id, &StateEventType::RoomEncryption, "") + .ok()? 
+ .is_some(), + ) + }) + .all(|encrypted| !encrypted); + // If the user doesn't share an encrypted room with the target anymore, we need + // to tell them + if dont_share_encrypted_room { + device_list_left.insert(user_id); + } + } + } - let mut lists = BTreeMap::new(); - let mut todo_rooms = BTreeMap::new(); // and required state + let mut lists = BTreeMap::new(); + let mut todo_rooms = BTreeMap::new(); // and required state - for (list_id, list) in body.lists { - if list.filters.and_then(|f| f.is_invite).unwrap_or(false) { - continue; - } + for (list_id, list) in body.lists { + if list.filters.and_then(|f| f.is_invite).unwrap_or(false) { + continue; + } - let mut new_known_rooms = BTreeSet::new(); + let mut new_known_rooms = BTreeSet::new(); - lists.insert( - list_id.clone(), - sync_events::v4::SyncList { - ops: list - .ranges - .into_iter() - .map(|mut r| { - r.0 = - r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1)); - r.1 = - r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1)); - let room_ids = all_joined_rooms - [(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)] - .to_vec(); - new_known_rooms.extend(room_ids.iter().cloned()); - for room_id in &room_ids { - let todo_room = todo_rooms.entry(room_id.clone()).or_insert(( - BTreeSet::new(), - 0, - u64::MAX, - )); - let limit = list - .room_details - .timeline_limit - .map_or(10, u64::from) - .min(100); - todo_room - .0 - .extend(list.room_details.required_state.iter().cloned()); - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get(&list_id) - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - } - sync_events::v4::SyncOp { - op: SlidingOp::Sync, - range: Some(r), - index: None, - room_ids, - room_id: None, - } - }) - .collect(), - count: UInt::from(all_joined_rooms.len() as u32), - }, - ); + lists.insert( + list_id.clone(), + sync_events::v4::SyncList { + ops: list + .ranges + 
.into_iter() + .map(|mut r| { + r.0 = r.0.clamp(uint!(0), UInt::from(all_joined_rooms.len() as u32 - 1)); + r.1 = r.1.clamp(r.0, UInt::from(all_joined_rooms.len() as u32 - 1)); + let room_ids = all_joined_rooms[(u64::from(r.0) as usize)..=(u64::from(r.1) as usize)].to_vec(); + new_known_rooms.extend(room_ids.iter().cloned()); + for room_id in &room_ids { + let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = list.room_details.timeline_limit.map_or(10, u64::from).min(100); + todo_room.0.extend(list.room_details.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = todo_room + .2 + .min(known_rooms.get(&list_id).and_then(|k| k.get(room_id)).copied().unwrap_or(0)); + } + sync_events::v4::SyncOp { + op: SlidingOp::Sync, + range: Some(r), + index: None, + room_ids, + room_id: None, + } + }) + .collect(), + count: UInt::from(all_joined_rooms.len() as u32), + }, + ); - if let Some(conn_id) = &body.conn_id { - services().users.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - list_id, - new_known_rooms, - globalsince, - ); - } - } + if let Some(conn_id) = &body.conn_id { + services().users.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + list_id, + new_known_rooms, + globalsince, + ); + } + } - let mut known_subscription_rooms = BTreeSet::new(); - for (room_id, room) in &body.room_subscriptions { - if !services().rooms.metadata.exists(room_id)? 
{ - continue; - } - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = room.timeline_limit.map_or(10, u64::from).min(100); - todo_room.0.extend(room.required_state.iter().cloned()); - todo_room.1 = todo_room.1.max(limit); - // 0 means unknown because it got out of date - todo_room.2 = todo_room.2.min( - known_rooms - .get("subscriptions") - .and_then(|k| k.get(room_id)) - .copied() - .unwrap_or(0), - ); - known_subscription_rooms.insert(room_id.clone()); - } + let mut known_subscription_rooms = BTreeSet::new(); + for (room_id, room) in &body.room_subscriptions { + if !services().rooms.metadata.exists(room_id)? { + continue; + } + let todo_room = todo_rooms.entry(room_id.clone()).or_insert((BTreeSet::new(), 0, u64::MAX)); + let limit = room.timeline_limit.map_or(10, u64::from).min(100); + todo_room.0.extend(room.required_state.iter().cloned()); + todo_room.1 = todo_room.1.max(limit); + // 0 means unknown because it got out of date + todo_room.2 = + todo_room.2.min(known_rooms.get("subscriptions").and_then(|k| k.get(room_id)).copied().unwrap_or(0)); + known_subscription_rooms.insert(room_id.clone()); + } - for r in body.unsubscribe_rooms { - known_subscription_rooms.remove(&r); - body.room_subscriptions.remove(&r); - } + for r in body.unsubscribe_rooms { + known_subscription_rooms.remove(&r); + body.room_subscriptions.remove(&r); + } - if let Some(conn_id) = &body.conn_id { - services().users.update_sync_known_rooms( - sender_user.clone(), - sender_device.clone(), - conn_id.clone(), - "subscriptions".to_owned(), - known_subscription_rooms, - globalsince, - ); - } + if let Some(conn_id) = &body.conn_id { + services().users.update_sync_known_rooms( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + "subscriptions".to_owned(), + known_subscription_rooms, + globalsince, + ); + } - if let Some(conn_id) = &body.conn_id { - services().users.update_sync_subscriptions( - sender_user.clone(), - 
sender_device.clone(), - conn_id.clone(), - body.room_subscriptions, - ); - } + if let Some(conn_id) = &body.conn_id { + services().users.update_sync_subscriptions( + sender_user.clone(), + sender_device.clone(), + conn_id.clone(), + body.room_subscriptions, + ); + } - let mut rooms = BTreeMap::new(); - for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { - let roomsincecount = PduCount::Normal(*roomsince); + let mut rooms = BTreeMap::new(); + for (room_id, (required_state_request, timeline_limit, roomsince)) in &todo_rooms { + let roomsincecount = PduCount::Normal(*roomsince); - let (timeline_pdus, limited) = - load_timeline(&sender_user, room_id, roomsincecount, *timeline_limit)?; + let (timeline_pdus, limited) = load_timeline(&sender_user, room_id, roomsincecount, *timeline_limit)?; - if roomsince != &0 && timeline_pdus.is_empty() { - continue; - } + if roomsince != &0 && timeline_pdus.is_empty() { + continue; + } - let prev_batch = timeline_pdus - .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - } - PduCount::Normal(c) => c.to_string(), - })) - })? - .or_else(|| { - if roomsince != &0 { - Some(roomsince.to_string()) - } else { - None - } - }); + let prev_batch = timeline_pdus + .first() + .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { + Ok(Some(match pdu_count { + PduCount::Backfilled(_) => { + error!("timeline in backfill state?!"); + "0".to_owned() + }, + PduCount::Normal(c) => c.to_string(), + })) + })? 
+ .or_else(|| { + if roomsince != &0 { + Some(roomsince.to_string()) + } else { + None + } + }); - let room_events: Vec<_> = timeline_pdus - .iter() - .map(|(_, pdu)| pdu.to_sync_room_event()) - .collect(); + let room_events: Vec<_> = timeline_pdus.iter().map(|(_, pdu)| pdu.to_sync_room_event()).collect(); - let required_state = required_state_request - .iter() - .map(|state| { - services() - .rooms - .state_accessor - .room_state_get(room_id, &state.0, &state.1) - }) - .filter_map(std::result::Result::ok) - .flatten() - .map(|state| state.to_sync_state_event()) - .collect(); + let required_state = required_state_request + .iter() + .map(|state| services().rooms.state_accessor.room_state_get(room_id, &state.0, &state.1)) + .filter_map(std::result::Result::ok) + .flatten() + .map(|state| state.to_sync_state_event()) + .collect(); - // Heroes - let heroes = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|member| member != &sender_user) - .map(|member| { - Ok::<_, Error>( - services() - .rooms - .state_accessor - .get_member(room_id, &member)? - .map(|memberevent| { - ( - memberevent - .displayname - .unwrap_or_else(|| member.to_string()), - memberevent.avatar_url, - ) - }), - ) - }) - .filter_map(std::result::Result::ok) - .flatten() - .take(5) - .collect::>(); - let name = match heroes.len().cmp(&(1_usize)) { - Ordering::Greater => { - let last = heroes[0].0.clone(); - Some( - heroes[1..] 
- .iter() - .map(|h| h.0.clone()) - .collect::>() - .join(", ") - + " and " - + &last, - ) - } - Ordering::Equal => Some(heroes[0].0.clone()), - Ordering::Less => None, - }; + // Heroes + let heroes = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|member| member != &sender_user) + .map(|member| { + Ok::<_, Error>( + services().rooms.state_accessor.get_member(room_id, &member)?.map(|memberevent| { + ( + memberevent.displayname.unwrap_or_else(|| member.to_string()), + memberevent.avatar_url, + ) + }), + ) + }) + .filter_map(std::result::Result::ok) + .flatten() + .take(5) + .collect::>(); + let name = match heroes.len().cmp(&(1_usize)) { + Ordering::Greater => { + let last = heroes[0].0.clone(); + Some(heroes[1..].iter().map(|h| h.0.clone()).collect::>().join(", ") + " and " + &last) + }, + Ordering::Equal => Some(heroes[0].0.clone()), + Ordering::Less => None, + }; - let heroes_avatar = if heroes.len() == 1 { - heroes[0].1.clone() - } else { - None - }; + let heroes_avatar = if heroes.len() == 1 { + heroes[0].1.clone() + } else { + None + }; - rooms.insert( - room_id.clone(), - sync_events::v4::SlidingSyncRoom { - name: services().rooms.state_accessor.get_name(room_id)?.or(name), - avatar: if let Some(heroes_avatar) = heroes_avatar { - ruma::JsOption::Some(heroes_avatar) - } else { - match services().rooms.state_accessor.get_avatar(room_id)? { - ruma::JsOption::Some(avatar) => avatar - .url - .map_or(ruma::JsOption::Undefined, ruma::JsOption::Some), - ruma::JsOption::Null => ruma::JsOption::Null, - ruma::JsOption::Undefined => ruma::JsOption::Undefined, - } - }, - initial: Some(roomsince == &0), - is_dm: None, - invite_state: None, - unread_notifications: UnreadNotificationsCount { - highlight_count: Some( - services() - .rooms - .user - .highlight_count(&sender_user, room_id)? 
- .try_into() - .expect("notification count can't go that high"), - ), - notification_count: Some( - services() - .rooms - .user - .notification_count(&sender_user, room_id)? - .try_into() - .expect("notification count can't go that high"), - ), - }, - timeline: room_events, - required_state, - prev_batch, - limited, - joined_count: Some( - (services() - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(0) as u32) - .into(), - ), - invited_count: Some( - (services() - .rooms - .state_cache - .room_invited_count(room_id)? - .unwrap_or(0) as u32) - .into(), - ), - num_live: None, // Count events in timeline greater than global sync counter - timestamp: None, - }, - ); - } + rooms.insert( + room_id.clone(), + sync_events::v4::SlidingSyncRoom { + name: services().rooms.state_accessor.get_name(room_id)?.or(name), + avatar: if let Some(heroes_avatar) = heroes_avatar { + ruma::JsOption::Some(heroes_avatar) + } else { + match services().rooms.state_accessor.get_avatar(room_id)? { + ruma::JsOption::Some(avatar) => { + avatar.url.map_or(ruma::JsOption::Undefined, ruma::JsOption::Some) + }, + ruma::JsOption::Null => ruma::JsOption::Null, + ruma::JsOption::Undefined => ruma::JsOption::Undefined, + } + }, + initial: Some(roomsince == &0), + is_dm: None, + invite_state: None, + unread_notifications: UnreadNotificationsCount { + highlight_count: Some( + services() + .rooms + .user + .highlight_count(&sender_user, room_id)? + .try_into() + .expect("notification count can't go that high"), + ), + notification_count: Some( + services() + .rooms + .user + .notification_count(&sender_user, room_id)? 
+ .try_into() + .expect("notification count can't go that high"), + ), + }, + timeline: room_events, + required_state, + prev_batch, + limited, + joined_count: Some( + (services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(0) as u32).into(), + ), + invited_count: Some( + (services().rooms.state_cache.room_invited_count(room_id)?.unwrap_or(0) as u32).into(), + ), + num_live: None, // Count events in timeline greater than global sync counter + timestamp: None, + }, + ); + } - if rooms - .iter() - .all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) - { - // Hang a few seconds so requests are not spammed - // Stop hanging if new info arrives - let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); - if duration.as_secs() > 30 { - duration = Duration::from_secs(30); - } - let _ = tokio::time::timeout(duration, watcher).await; - } + if rooms.iter().all(|(_, r)| r.timeline.is_empty() && r.required_state.is_empty()) { + // Hang a few seconds so requests are not spammed + // Stop hanging if new info arrives + let mut duration = body.timeout.unwrap_or(Duration::from_secs(30)); + if duration.as_secs() > 30 { + duration = Duration::from_secs(30); + } + let _ = tokio::time::timeout(duration, watcher).await; + } - Ok(sync_events::v4::Response { - initial: globalsince == 0, - txn_id: body.txn_id.clone(), - pos: next_batch.to_string(), - lists, - rooms, - extensions: sync_events::v4::Extensions { - to_device: if body.extensions.to_device.enabled.unwrap_or(false) { - Some(sync_events::v4::ToDevice { - events: services() - .users - .get_to_device_events(&sender_user, &sender_device)?, - next_batch: next_batch.to_string(), - }) - } else { - None - }, - e2ee: sync_events::v4::E2EE { - device_lists: DeviceLists { - changed: device_list_changes.into_iter().collect(), - left: device_list_left.into_iter().collect(), - }, - device_one_time_keys_count: services() - .users - .count_one_time_keys(&sender_user, &sender_device)?, - // Fallback keys are 
not yet supported - device_unused_fallback_key_types: None, - }, - account_data: sync_events::v4::AccountData { - global: if body.extensions.account_data.enabled.unwrap_or(false) { - services() - .account_data - .changes_since(None, &sender_user, globalsince)? - .into_iter() - .filter_map(|(_, v)| { - serde_json::from_str(v.json().get()) - .map_err(|_| { - Error::bad_database("Invalid account event in database.") - }) - .ok() - }) - .collect() - } else { - Vec::new() - }, - rooms: BTreeMap::new(), - }, - receipts: sync_events::v4::Receipts { - rooms: BTreeMap::new(), - }, - typing: sync_events::v4::Typing { - rooms: BTreeMap::new(), - }, - }, - delta_token: None, - }) + Ok(sync_events::v4::Response { + initial: globalsince == 0, + txn_id: body.txn_id.clone(), + pos: next_batch.to_string(), + lists, + rooms, + extensions: sync_events::v4::Extensions { + to_device: if body.extensions.to_device.enabled.unwrap_or(false) { + Some(sync_events::v4::ToDevice { + events: services().users.get_to_device_events(&sender_user, &sender_device)?, + next_batch: next_batch.to_string(), + }) + } else { + None + }, + e2ee: sync_events::v4::E2EE { + device_lists: DeviceLists { + changed: device_list_changes.into_iter().collect(), + left: device_list_left.into_iter().collect(), + }, + device_one_time_keys_count: services().users.count_one_time_keys(&sender_user, &sender_device)?, + // Fallback keys are not yet supported + device_unused_fallback_key_types: None, + }, + account_data: sync_events::v4::AccountData { + global: if body.extensions.account_data.enabled.unwrap_or(false) { + services() + .account_data + .changes_since(None, &sender_user, globalsince)? 
+ .into_iter() + .filter_map(|(_, v)| { + serde_json::from_str(v.json().get()) + .map_err(|_| Error::bad_database("Invalid account event in database.")) + .ok() + }) + .collect() + } else { + Vec::new() + }, + rooms: BTreeMap::new(), + }, + receipts: sync_events::v4::Receipts { + rooms: BTreeMap::new(), + }, + typing: sync_events::v4::Typing { + rooms: BTreeMap::new(), + }, + }, + delta_token: None, + }) } diff --git a/src/api/client_server/tag.rs b/src/api/client_server/tag.rs index 16f1600f..381596fd 100644 --- a/src/api/client_server/tag.rs +++ b/src/api/client_server/tag.rs @@ -1,55 +1,45 @@ -use crate::{services, Error, Result, Ruma}; -use ruma::{ - api::client::tag::{create_tag, delete_tag, get_tags}, - events::{ - tag::{TagEvent, TagEventContent}, - RoomAccountDataEventType, - }, -}; use std::collections::BTreeMap; +use ruma::{ + api::client::tag::{create_tag, delete_tag, get_tags}, + events::{ + tag::{TagEvent, TagEventContent}, + RoomAccountDataEventType, + }, +}; + +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` /// /// Adds a tag to the room. /// /// - Inserts the tag into the tag event of the room account data. 
-pub async fn update_tag_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn update_tag_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )?; + let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - let mut tags_event = event - .map(|e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) - }) - .unwrap_or_else(|| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - })?; + let mut tags_event = event + .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; - tags_event - .content - .tags - .insert(body.tag.clone().into(), body.tag_info.clone()); + tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone()); - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + )?; - Ok(create_tag::v3::Response {}) + Ok(create_tag::v3::Response {}) } /// # `DELETE /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}` @@ -57,40 +47,31 @@ pub async fn update_tag_route( /// Deletes a tag from the room. /// /// - Removes the tag from the tag event of the room account data. 
-pub async fn delete_tag_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn delete_tag_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )?; + let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - let mut tags_event = event - .map(|e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) - }) - .unwrap_or_else(|| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - })?; + let mut tags_event = event + .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; - tags_event.content.tags.remove(&body.tag.clone().into()); + tags_event.content.tags.remove(&body.tag.clone().into()); - services().account_data.update( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - &serde_json::to_value(tags_event).expect("to json value always works"), - )?; + services().account_data.update( + Some(&body.room_id), + sender_user, + RoomAccountDataEventType::Tag, + &serde_json::to_value(tags_event).expect("to json value always works"), + )?; - Ok(delete_tag::v3::Response {}) + Ok(delete_tag::v3::Response {}) } /// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags` @@ -99,28 +80,21 @@ pub async fn delete_tag_route( /// /// - Gets the tag event of the room account data. 
pub async fn get_tags_route(body: Ruma) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let event = services().account_data.get( - Some(&body.room_id), - sender_user, - RoomAccountDataEventType::Tag, - )?; + let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?; - let tags_event = event - .map(|e| { - serde_json::from_str(e.get()) - .map_err(|_| Error::bad_database("Invalid account data event in db.")) - }) - .unwrap_or_else(|| { - Ok(TagEvent { - content: TagEventContent { - tags: BTreeMap::new(), - }, - }) - })?; + let tags_event = event + .map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db."))) + .unwrap_or_else(|| { + Ok(TagEvent { + content: TagEventContent { + tags: BTreeMap::new(), + }, + }) + })?; - Ok(get_tags::v3::Response { - tags: tags_event.content.tags, - }) + Ok(get_tags::v3::Response { + tags: tags_event.content.tags, + }) } diff --git a/src/api/client_server/thirdparty.rs b/src/api/client_server/thirdparty.rs index c2c1adfd..f5de4c61 100644 --- a/src/api/client_server/thirdparty.rs +++ b/src/api/client_server/thirdparty.rs @@ -1,16 +1,15 @@ -use crate::{Result, Ruma}; +use std::collections::BTreeMap; + use ruma::api::client::thirdparty::get_protocols; -use std::collections::BTreeMap; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/thirdparty/protocols` /// /// TODO: Fetches all metadata about protocols supported by the homeserver. 
-pub async fn get_protocols_route( - _body: Ruma, -) -> Result { - // TODO - Ok(get_protocols::v3::Response { - protocols: BTreeMap::new(), - }) +pub async fn get_protocols_route(_body: Ruma) -> Result { + // TODO + Ok(get_protocols::v3::Response { + protocols: BTreeMap::new(), + }) } diff --git a/src/api/client_server/threads.rs b/src/api/client_server/threads.rs index e1731060..a154a24c 100644 --- a/src/api/client_server/threads.rs +++ b/src/api/client_server/threads.rs @@ -3,47 +3,37 @@ use ruma::api::client::{error::ErrorKind, threads::get_threads}; use crate::{services, Error, Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/threads` -pub async fn get_threads_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub async fn get_threads_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - // Use limit or else 10, with maximum 100 - let limit = body - .limit - .and_then(|l| l.try_into().ok()) - .unwrap_or(10) - .min(100); + // Use limit or else 10, with maximum 100 + let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100); - let from = if let Some(from) = &body.from { - from.parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? - } else { - u64::MAX - }; + let from = if let Some(from) = &body.from { + from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? + } else { + u64::MAX + }; - let threads = services() - .rooms - .threads - .threads_until(sender_user, &body.room_id, from, &body.include)? - .take(limit) - .filter_map(std::result::Result::ok) - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) - .unwrap_or(false) - }) - .collect::>(); + let threads = services() + .rooms + .threads + .threads_until(sender_user, &body.room_id, from, &body.include)? 
+ .take(limit) + .filter_map(std::result::Result::ok) + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, &body.room_id, &pdu.event_id) + .unwrap_or(false) + }) + .collect::>(); - let next_batch = threads.last().map(|(count, _)| count.to_string()); + let next_batch = threads.last().map(|(count, _)| count.to_string()); - Ok(get_threads::v1::Response { - chunk: threads - .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) - .collect(), - next_batch, - }) + Ok(get_threads::v1::Response { + chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(), + next_batch, + }) } diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index bce893ed..7e97f61e 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -1,92 +1,85 @@ use std::collections::BTreeMap; -use crate::{services, Error, Result, Ruma}; use ruma::{ - api::{ - client::{error::ErrorKind, to_device::send_event_to_device}, - federation::{self, transactions::edu::DirectDeviceContent}, - }, - to_device::DeviceIdOrAllDevices, + api::{ + client::{error::ErrorKind, to_device::send_event_to_device}, + federation::{self, transactions::edu::DirectDeviceContent}, + }, + to_device::DeviceIdOrAllDevices, }; +use crate::{services, Error, Result, Ruma}; + /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// /// Send a to-device event to a set of client devices. pub async fn send_event_to_device_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let sender_device = body.sender_device.as_deref(); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_device = body.sender_device.as_deref(); - // Check if this is a new transaction id - if services() - .transaction_ids - .existing_txnid(sender_user, sender_device, &body.txn_id)? 
- .is_some() - { - return Ok(send_event_to_device::v3::Response {}); - } + // Check if this is a new transaction id + if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() { + return Ok(send_event_to_device::v3::Response {}); + } - for (target_user_id, map) in &body.messages { - for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != services().globals.server_name() { - let mut map = BTreeMap::new(); - map.insert(target_device_id_maybe.clone(), event.clone()); - let mut messages = BTreeMap::new(); - messages.insert(target_user_id.clone(), map); - let count = services().globals.next_count()?; + for (target_user_id, map) in &body.messages { + for (target_device_id_maybe, event) in map { + if target_user_id.server_name() != services().globals.server_name() { + let mut map = BTreeMap::new(); + map.insert(target_device_id_maybe.clone(), event.clone()); + let mut messages = BTreeMap::new(); + messages.insert(target_user_id.clone(), map); + let count = services().globals.next_count()?; - services().sending.send_reliable_edu( - target_user_id.server_name(), - serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice( - DirectDeviceContent { - sender: sender_user.clone(), - ev_type: body.event_type.clone(), - message_id: count.to_string().into(), - messages, - }, - )) - .expect("DirectToDevice EDU can be serialized"), - count, - )?; + services().sending.send_reliable_edu( + target_user_id.server_name(), + serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent { + sender: sender_user.clone(), + ev_type: body.event_type.clone(), + message_id: count.to_string().into(), + messages, + })) + .expect("DirectToDevice EDU can be serialized"), + count, + )?; - continue; - } + continue; + } - match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( - sender_user, - target_user_id, - 
target_device_id, - &body.event_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, - )?; - } + match target_device_id_maybe { + DeviceIdOrAllDevices::DeviceId(target_device_id) => { + services().users.add_to_device_event( + sender_user, + target_user_id, + target_device_id, + &body.event_type.to_string(), + event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, + )?; + }, - DeviceIdOrAllDevices::AllDevices => { - for target_device_id in services().users.all_device_ids(target_user_id) { - services().users.add_to_device_event( - sender_user, - target_user_id, - &target_device_id?, - &body.event_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") - })?, - )?; - } - } - } - } - } + DeviceIdOrAllDevices::AllDevices => { + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( + sender_user, + target_user_id, + &target_device_id?, + &body.event_type.to_string(), + event + .deserialize_as() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?, + )?; + } + }, + } + } + } - // Save transaction id with empty data - services() - .transaction_ids - .add_txnid(sender_user, sender_device, &body.txn_id, &[])?; + // Save transaction id with empty data + services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?; - Ok(send_event_to_device::v3::Response {}) + Ok(send_event_to_device::v3::Response {}) } diff --git a/src/api/client_server/typing.rs b/src/api/client_server/typing.rs index 43217e1a..a51c3b88 100644 --- a/src/api/client_server/typing.rs +++ b/src/api/client_server/typing.rs @@ -1,40 +1,30 @@ -use crate::{services, utils, Error, Result, Ruma}; use ruma::api::client::{error::ErrorKind, typing::create_typing_event}; +use crate::{services, utils, Error, Result, Ruma}; 
+ /// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}` /// /// Sets the typing state of the sender user. pub async fn create_typing_event_route( - body: Ruma, + body: Ruma, ) -> Result { - use create_typing_event::v3::Typing; + use create_typing_event::v3::Typing; - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - if !services() - .rooms - .state_cache - .is_joined(sender_user, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "You are not in this room.", - )); - } + if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room.")); + } - if let Typing::Yes(duration) = body.state { - services().rooms.edus.typing.typing_add( - sender_user, - &body.room_id, - duration.as_millis() as u64 + utils::millis_since_unix_epoch(), - )?; - } else { - services() - .rooms - .edus - .typing - .typing_remove(sender_user, &body.room_id)?; - } + if let Typing::Yes(duration) = body.state { + services().rooms.edus.typing.typing_add( + sender_user, + &body.room_id, + duration.as_millis() as u64 + utils::millis_since_unix_epoch(), + )?; + } else { + services().rooms.edus.typing.typing_remove(sender_user, &body.room_id)?; + } - Ok(create_typing_event::v3::Response {}) + Ok(create_typing_event::v3::Response {}) } diff --git a/src/api/client_server/unversioned.rs b/src/api/client_server/unversioned.rs index e092eb05..af911fd9 100644 --- a/src/api/client_server/unversioned.rs +++ b/src/api/client_server/unversioned.rs @@ -7,72 +7,74 @@ use crate::{services, Error, Result, Ruma}; /// # `GET /_matrix/client/versions` /// -/// Get the versions of the specification and unstable features supported by this server. +/// Get the versions of the specification and unstable features supported by +/// this server. 
/// /// - Versions take the form MAJOR.MINOR.PATCH /// - Only the latest PATCH release will be reported for each MAJOR.MINOR value -/// - Unstable features are namespaced and may include version information in their name +/// - Unstable features are namespaced and may include version information in +/// their name /// -/// Note: Unstable features are used while developing new features. Clients should avoid using -/// unstable features in their stable releases +/// Note: Unstable features are used while developing new features. Clients +/// should avoid using unstable features in their stable releases pub async fn get_supported_versions_route( - _body: Ruma, + _body: Ruma, ) -> Result { - let resp = get_supported_versions::Response { - versions: vec![ - "r0.0.1".to_owned(), - "r0.1.0".to_owned(), - "r0.2.0".to_owned(), - "r0.3.0".to_owned(), - "r0.4.0".to_owned(), - "r0.5.0".to_owned(), - "r0.6.0".to_owned(), - "r0.6.1".to_owned(), - "v1.1".to_owned(), - "v1.2".to_owned(), - "v1.3".to_owned(), - "v1.4".to_owned(), - "v1.5".to_owned(), - ], - unstable_features: BTreeMap::from_iter([ - ("org.matrix.e2e_cross_signing".to_owned(), true), - ("org.matrix.msc2836".to_owned(), true), - ("org.matrix.msc3827".to_owned(), true), - ("org.matrix.msc2946".to_owned(), true), - ]), - }; + let resp = get_supported_versions::Response { + versions: vec![ + "r0.0.1".to_owned(), + "r0.1.0".to_owned(), + "r0.2.0".to_owned(), + "r0.3.0".to_owned(), + "r0.4.0".to_owned(), + "r0.5.0".to_owned(), + "r0.6.0".to_owned(), + "r0.6.1".to_owned(), + "v1.1".to_owned(), + "v1.2".to_owned(), + "v1.3".to_owned(), + "v1.4".to_owned(), + "v1.5".to_owned(), + ], + unstable_features: BTreeMap::from_iter([ + ("org.matrix.e2e_cross_signing".to_owned(), true), + ("org.matrix.msc2836".to_owned(), true), + ("org.matrix.msc3827".to_owned(), true), + ("org.matrix.msc2946".to_owned(), true), + ]), + }; - Ok(resp) + Ok(resp) } /// # `GET /.well-known/matrix/client` pub async fn well_known_client_route() -> Result 
{ - let client_url = match services().globals.well_known_client() { - Some(url) => url.clone(), - None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), - }; + let client_url = match services().globals.well_known_client() { + Some(url) => url.clone(), + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), + }; - Ok(Json(serde_json::json!({ - "m.homeserver": {"base_url": client_url}, - "org.matrix.msc3575.proxy": {"url": client_url} - }))) + Ok(Json(serde_json::json!({ + "m.homeserver": {"base_url": client_url}, + "org.matrix.msc3575.proxy": {"url": client_url} + }))) } /// # `GET /client/server.json` /// -/// Endpoint provided by sliding sync proxy used by some clients such as Element Web -/// as a non-standard health check. +/// Endpoint provided by sliding sync proxy used by some clients such as Element +/// Web as a non-standard health check. pub async fn syncv3_client_server_json() -> Result { - let server_url = match services().globals.well_known_client() { - Some(url) => url.clone(), - None => match services().globals.well_known_server() { - Some(url) => url.clone(), - None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), - }, - }; + let server_url = match services().globals.well_known_client() { + Some(url) => url.clone(), + None => match services().globals.well_known_server() { + Some(url) => url.clone(), + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), + }, + }; - Ok(Json(serde_json::json!({ - "server": server_url, - "version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")) - }))) + Ok(Json(serde_json::json!({ + "server": server_url, + "version": format!("{} {}", env!("CARGO_PKG_NAME"), env!("CARGO_PKG_VERSION")) + }))) } diff --git a/src/api/client_server/user_directory.rs b/src/api/client_server/user_directory.rs index f648cb7b..b71c42e0 100644 --- a/src/api/client_server/user_directory.rs +++ b/src/api/client_server/user_directory.rs @@ -1,94 
+1,78 @@ -use crate::{services, Result, Ruma}; use ruma::{ - api::client::user_directory::search_users, - events::{ - room::join_rules::{JoinRule, RoomJoinRulesEventContent}, - StateEventType, - }, + api::client::user_directory::search_users, + events::{ + room::join_rules::{JoinRule, RoomJoinRulesEventContent}, + StateEventType, + }, }; +use crate::{services, Result, Ruma}; + /// # `POST /_matrix/client/r0/user_directory/search` /// /// Searches all known users for a match. /// -/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public) +/// - Hides any local users that aren't in any public rooms (i.e. those that +/// have the join rule set to public) /// and don't share a room with the sender -pub async fn search_users_route( - body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let limit = u64::from(body.limit) as usize; +pub async fn search_users_route(body: Ruma) -> Result { + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let limit = u64::from(body.limit) as usize; - let mut users = services().users.iter().filter_map(|user_id| { - // Filter out buggy users (they should not exist, but you never know...) - let user_id = user_id.ok()?; + let mut users = services().users.iter().filter_map(|user_id| { + // Filter out buggy users (they should not exist, but you never know...) 
+ let user_id = user_id.ok()?; - let user = search_users::v3::User { - user_id: user_id.clone(), - display_name: services().users.displayname(&user_id).ok()?, - avatar_url: services().users.avatar_url(&user_id).ok()?, - }; + let user = search_users::v3::User { + user_id: user_id.clone(), + display_name: services().users.displayname(&user_id).ok()?, + avatar_url: services().users.avatar_url(&user_id).ok()?, + }; - let user_id_matches = user - .user_id - .to_string() - .to_lowercase() - .contains(&body.search_term.to_lowercase()); + let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase()); - let user_displayname_matches = user - .display_name - .as_ref() - .filter(|name| { - name.to_lowercase() - .contains(&body.search_term.to_lowercase()) - }) - .is_some(); + let user_displayname_matches = user + .display_name + .as_ref() + .filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase())) + .is_some(); - if !user_id_matches && !user_displayname_matches { - return None; - } + if !user_id_matches && !user_displayname_matches { + return None; + } - let user_is_in_public_rooms = services() - .rooms - .state_cache - .rooms_joined(&user_id) - .filter_map(std::result::Result::ok) - .any(|room| { - services() - .rooms - .state_accessor - .room_state_get(&room, &StateEventType::RoomJoinRules, "") - .map_or(false, |event| { - event.map_or(false, |event| { - serde_json::from_str(event.content.get()) - .map_or(false, |r: RoomJoinRulesEventContent| { - r.join_rule == JoinRule::Public - }) - }) - }) - }); + let user_is_in_public_rooms = + services().rooms.state_cache.rooms_joined(&user_id).filter_map(std::result::Result::ok).any(|room| { + services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or( + false, + |event| { + event.map_or(false, |event| { + serde_json::from_str(event.content.get()) + .map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public) + }) + }, 
+ ) + }); - if user_is_in_public_rooms { - return Some(user); - } + if user_is_in_public_rooms { + return Some(user); + } - let user_is_in_shared_rooms = services() - .rooms - .user - .get_shared_rooms(vec![sender_user.clone(), user_id]) - .ok()? - .next() - .is_some(); + let user_is_in_shared_rooms = + services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some(); - if user_is_in_shared_rooms { - return Some(user); - } + if user_is_in_shared_rooms { + return Some(user); + } - None - }); + None + }); - let results = users.by_ref().take(limit).collect(); - let limited = users.next().is_some(); + let results = users.by_ref().take(limit).collect(); + let limited = users.next().is_some(); - Ok(search_users::v3::Response { results, limited }) + Ok(search_users::v3::Response { + results, + limited, + }) } diff --git a/src/api/client_server/voip.rs b/src/api/client_server/voip.rs index f0d91f71..5bd10ea3 100644 --- a/src/api/client_server/voip.rs +++ b/src/api/client_server/voip.rs @@ -1,9 +1,11 @@ -use crate::{services, Result, Ruma}; +use std::time::{Duration, SystemTime}; + use base64::{engine::general_purpose, Engine as _}; use hmac::{Hmac, Mac}; use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch}; use sha1::Sha1; -use std::time::{Duration, SystemTime}; + +use crate::{services, Result, Ruma}; type HmacSha1 = Hmac; @@ -11,38 +13,37 @@ type HmacSha1 = Hmac; /// /// TODO: Returns information about the recommended turn server. 
pub async fn turn_server_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let turn_secret = services().globals.turn_secret().clone(); + let turn_secret = services().globals.turn_secret().clone(); - let (username, password) = if !turn_secret.is_empty() { - let expiry = SecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), - ) - .expect("time is valid"); + let (username, password) = if !turn_secret.is_empty() { + let expiry = SecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + Duration::from_secs(services().globals.turn_ttl()), + ) + .expect("time is valid"); - let username: String = format!("{}:{}", expiry.get(), sender_user); + let username: String = format!("{}:{}", expiry.get(), sender_user); - let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()) - .expect("HMAC can take key of any size"); - mac.update(username.as_bytes()); + let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size"); + mac.update(username.as_bytes()); - let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes()); + let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes()); - (username, password) - } else { - ( - services().globals.turn_username().clone(), - services().globals.turn_password().clone(), - ) - }; + (username, password) + } else { + ( + services().globals.turn_username().clone(), + services().globals.turn_password().clone(), + ) + }; - Ok(get_turn_server_info::v3::Response { - username, - password, - uris: services().globals.turn_uris().to_vec(), - ttl: Duration::from_secs(services().globals.turn_ttl()), - }) + Ok(get_turn_server_info::v3::Response { + username, + password, + uris: services().globals.turn_uris().to_vec(), + ttl: 
Duration::from_secs(services().globals.turn_ttl()), + }) } diff --git a/src/api/ruma_wrapper/axum.rs b/src/api/ruma_wrapper/axum.rs index 1117874c..36a7ba87 100644 --- a/src/api/ruma_wrapper/axum.rs +++ b/src/api/ruma_wrapper/axum.rs @@ -1,21 +1,21 @@ use std::{collections::BTreeMap, str}; use axum::{ - async_trait, - body::{Full, HttpBody}, - extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, - headers::{ - authorization::{Bearer, Credentials}, - Authorization, - }, - response::{IntoResponse, Response}, - BoxError, RequestExt, RequestPartsExt, + async_trait, + body::{Full, HttpBody}, + extract::{rejection::TypedHeaderRejectionReason, FromRequest, Path, TypedHeader}, + headers::{ + authorization::{Bearer, Credentials}, + Authorization, + }, + response::{IntoResponse, Response}, + BoxError, RequestExt, RequestPartsExt, }; use bytes::{Buf, BufMut, Bytes, BytesMut}; use http::{Request, StatusCode}; use ruma::{ - api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, - CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, + api::{client::error::ErrorKind, AuthScheme, IncomingRequest, OutgoingResponse}, + CanonicalJsonValue, OwnedDeviceId, OwnedServerName, UserId, }; use serde::Deserialize; use tracing::{debug, error, warn}; @@ -25,400 +25,333 @@ use crate::{services, Error, Result}; #[derive(Deserialize)] struct QueryParams { - access_token: Option, - user_id: Option, + access_token: Option, + user_id: Option, } #[async_trait] impl FromRequest for Ruma where - T: IncomingRequest, - B: HttpBody + Send + 'static, - B::Data: Send, - B::Error: Into, + T: IncomingRequest, + B: HttpBody + Send + 'static, + B::Data: Send, + B::Error: Into, { - type Rejection = Error; + type Rejection = Error; - async fn from_request(req: Request, _state: &S) -> Result { - let (mut parts, mut body) = match req.with_limited_body() { - Ok(limited_req) => { - let (parts, body) = limited_req.into_parts(); - let body = to_bytes(body) - 
.await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; - (parts, body) - } - Err(original_req) => { - let (parts, body) = original_req.into_parts(); - let body = to_bytes(body) - .await - .map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; - (parts, body) - } - }; + async fn from_request(req: Request, _state: &S) -> Result { + let (mut parts, mut body) = match req.with_limited_body() { + Ok(limited_req) => { + let (parts, body) = limited_req.into_parts(); + let body = + to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + }, + Err(original_req) => { + let (parts, body) = original_req.into_parts(); + let body = + to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?; + (parts, body) + }, + }; - let metadata = T::METADATA; - let auth_header: Option>> = parts.extract().await?; - let path_params: Path> = parts.extract().await?; + let metadata = T::METADATA; + let auth_header: Option>> = parts.extract().await?; + let path_params: Path> = parts.extract().await?; - let query = parts.uri.query().unwrap_or_default(); - let query_params: QueryParams = match serde_html_form::from_str(query) { - Ok(params) => params, - Err(e) => { - error!(%query, "Failed to deserialize query parameters: {}", e); - return Err(Error::BadRequest( - ErrorKind::Unknown, - "Failed to read query parameters", - )); - } - }; + let query = parts.uri.query().unwrap_or_default(); + let query_params: QueryParams = match serde_html_form::from_str(query) { + Ok(params) => params, + Err(e) => { + error!(%query, "Failed to deserialize query parameters: {}", e); + return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters")); + }, + }; - let token = match &auth_header { - Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), - None => query_params.access_token.as_deref(), - }; + let token = match &auth_header { + 
Some(TypedHeader(Authorization(bearer))) => Some(bearer.token()), + None => query_params.access_token.as_deref(), + }; - let mut json_body = serde_json::from_slice::(&body).ok(); + let mut json_body = serde_json::from_slice::(&body).ok(); - let appservices = services().appservice.all().unwrap(); - let appservice_registration = appservices - .iter() - .find(|(_id, registration)| Some(registration.as_token.as_str()) == token); + let appservices = services().appservice.all().unwrap(); + let appservice_registration = + appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token); - let (sender_user, sender_device, sender_servername, from_appservice) = - if let Some((_id, registration)) = appservice_registration { - match metadata.authentication { - AuthScheme::AccessToken => { - let user_id = query_params.user_id.map_or_else( - || { - UserId::parse_with_server_name( - registration.sender_localpart.as_str(), - services().globals.server_name(), - ) - .unwrap() - }, - |s| UserId::parse(s).unwrap(), - ); + let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) = + appservice_registration + { + match metadata.authentication { + AuthScheme::AccessToken => { + let user_id = query_params.user_id.map_or_else( + || { + UserId::parse_with_server_name( + registration.sender_localpart.as_str(), + services().globals.server_name(), + ) + .unwrap() + }, + |s| UserId::parse(s).unwrap(), + ); - if !services().users.exists(&user_id).unwrap() { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "User does not exist.", - )); - } + if !services().users.exists(&user_id).unwrap() { + return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist.")); + } - // TODO: Check if appservice is allowed to be that user - (Some(user_id), None, None, true) - } - AuthScheme::ServerSignatures => (None, None, None, true), - AuthScheme::None => (None, None, None, true), - } - } else { - match 
metadata.authentication { - AuthScheme::AccessToken => { - let token = match token { - Some(token) => token, - _ => { - return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing access token.", - )) - } - }; + // TODO: Check if appservice is allowed to be that user + (Some(user_id), None, None, true) + }, + AuthScheme::ServerSignatures => (None, None, None, true), + AuthScheme::None => (None, None, None, true), + } + } else { + match metadata.authentication { + AuthScheme::AccessToken => { + let token = match token { + Some(token) => token, + _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), + }; - match services().users.find_from_token(token).unwrap() { - None => { - return Err(Error::BadRequest( - ErrorKind::UnknownToken { soft_logout: false }, - "Unknown access token.", - )) - } - Some((user_id, device_id)) => ( - Some(user_id), - Some(OwnedDeviceId::from(device_id)), - None, - false, - ), - } - } - AuthScheme::ServerSignatures => { - let TypedHeader(Authorization(x_matrix)) = parts - .extract::>>() - .await - .map_err(|e| { - warn!("Missing or invalid Authorization header: {}", e); + match services().users.find_from_token(token).unwrap() { + None => { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )) + }, + Some((user_id, device_id)) => { + (Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false) + }, + } + }, + AuthScheme::ServerSignatures => { + let TypedHeader(Authorization(x_matrix)) = + parts.extract::>>().await.map_err(|e| { + warn!("Missing or invalid Authorization header: {}", e); - let msg = match e.reason() { - TypedHeaderRejectionReason::Missing => { - "Missing Authorization header." - } - TypedHeaderRejectionReason::Error(_) => { - "Invalid X-Matrix signatures." 
- } - _ => "Unknown header-related error", - }; + let msg = match e.reason() { + TypedHeaderRejectionReason::Missing => "Missing Authorization header.", + TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.", + _ => "Unknown header-related error", + }; - Error::BadRequest(ErrorKind::Forbidden, msg) - })?; + Error::BadRequest(ErrorKind::Forbidden, msg) + })?; - let origin_signatures = BTreeMap::from_iter([( - x_matrix.key.clone(), - CanonicalJsonValue::String(x_matrix.sig), - )]); + let origin_signatures = + BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]); - let signatures = BTreeMap::from_iter([( - x_matrix.origin.as_str().to_owned(), - CanonicalJsonValue::Object(origin_signatures), - )]); + let signatures = BTreeMap::from_iter([( + x_matrix.origin.as_str().to_owned(), + CanonicalJsonValue::Object(origin_signatures), + )]); - let server_destination = - services().globals.server_name().as_str().to_owned(); + let server_destination = services().globals.server_name().as_str().to_owned(); - if let Some(destination) = x_matrix.destination.as_ref() { - if destination != &server_destination { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Invalid authorization.", - )); - } - } + if let Some(destination) = x_matrix.destination.as_ref() { + if destination != &server_destination { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization.")); + } + } - let mut request_map = BTreeMap::from_iter([ - ( - "method".to_owned(), - CanonicalJsonValue::String(parts.method.to_string()), - ), - ( - "uri".to_owned(), - CanonicalJsonValue::String(parts.uri.to_string()), - ), - ( - "origin".to_owned(), - CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), - ), - ( - "destination".to_owned(), - CanonicalJsonValue::String(server_destination), - ), - ( - "signatures".to_owned(), - CanonicalJsonValue::Object(signatures), - ), - ]); + let mut request_map = BTreeMap::from_iter([ + 
("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())), + ("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())), + ( + "origin".to_owned(), + CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()), + ), + ("destination".to_owned(), CanonicalJsonValue::String(server_destination)), + ("signatures".to_owned(), CanonicalJsonValue::Object(signatures)), + ]); - if let Some(json_body) = &json_body { - request_map.insert("content".to_owned(), json_body.clone()); - }; + if let Some(json_body) = &json_body { + request_map.insert("content".to_owned(), json_body.clone()); + }; - let keys_result = services() - .rooms - .event_handler - .fetch_signing_keys_for_server( - &x_matrix.origin, - vec![x_matrix.key.clone()], - ) - .await; + let keys_result = services() + .rooms + .event_handler + .fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()]) + .await; - let keys = match keys_result { - Ok(b) => b, - Err(e) => { - warn!("Failed to fetch signing keys: {}", e); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Failed to fetch signing keys.", - )); - } - }; + let keys = match keys_result { + Ok(b) => b, + Err(e) => { + warn!("Failed to fetch signing keys: {}", e); + return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys.")); + }, + }; - let pub_key_map = - BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); + let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]); - match ruma::signatures::verify_json(&pub_key_map, &request_map) { - Ok(()) => (None, None, Some(x_matrix.origin), false), - Err(e) => { - warn!( - "Failed to verify json request from {}: {}\n{:?}", - x_matrix.origin, e, request_map - ); + match ruma::signatures::verify_json(&pub_key_map, &request_map) { + Ok(()) => (None, None, Some(x_matrix.origin), false), + Err(e) => { + warn!( + "Failed to verify json request from {}: {}\n{:?}", + x_matrix.origin, e, request_map + ); - 
if parts.uri.to_string().contains('@') { - warn!( - "Request uri contained '@' character. Make sure your \ - reverse proxy gives Conduit the raw uri (apache: use \ - nocanon)" - ); - } + if parts.uri.to_string().contains('@') { + warn!( + "Request uri contained '@' character. Make sure your reverse proxy gives Conduit \ + the raw uri (apache: use nocanon)" + ); + } - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Failed to verify X-Matrix signatures.", - )); - } - } - } - AuthScheme::None => match parts.uri.path() { - // allow_public_room_directory_without_auth - "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { - if !services() - .globals - .config - .allow_public_room_directory_without_auth - { - let token = match token { - Some(token) => token, - _ => { - return Err(Error::BadRequest( - ErrorKind::MissingToken, - "Missing access token.", - )) - } - }; + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Failed to verify X-Matrix signatures.", + )); + }, + } + }, + AuthScheme::None => match parts.uri.path() { + // allow_public_room_directory_without_auth + "/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => { + if !services().globals.config.allow_public_room_directory_without_auth { + let token = match token { + Some(token) => token, + _ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")), + }; - match services().users.find_from_token(token).unwrap() { - None => { - return Err(Error::BadRequest( - ErrorKind::UnknownToken { soft_logout: false }, - "Unknown access token.", - )) - } - Some((user_id, device_id)) => ( - Some(user_id), - Some(OwnedDeviceId::from(device_id)), - None, - false, - ), - } - } else { - (None, None, None, false) - } - } - _ => (None, None, None, false), - }, - } - }; + match services().users.find_from_token(token).unwrap() { + None => { + return Err(Error::BadRequest( + ErrorKind::UnknownToken { + soft_logout: false, + }, + "Unknown access token.", + )) + 
}, + Some((user_id, device_id)) => { + (Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false) + }, + } + } else { + (None, None, None, false) + } + }, + _ => (None, None, None, false), + }, + } + }; - let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); - *http_request.headers_mut().unwrap() = parts.headers; + let mut http_request = http::Request::builder().uri(parts.uri).method(parts.method); + *http_request.headers_mut().unwrap() = parts.headers; - if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { - let user_id = sender_user.clone().unwrap_or_else(|| { - UserId::parse_with_server_name("", services().globals.server_name()) - .expect("we know this is valid") - }); + if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body { + let user_id = sender_user.clone().unwrap_or_else(|| { + UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid") + }); - let uiaa_request = json_body - .get("auth") - .and_then(|auth| auth.as_object()) - .and_then(|auth| auth.get("session")) - .and_then(|session| session.as_str()) - .and_then(|session| { - services().uiaa.get_uiaa_request( - &user_id, - &sender_device.clone().unwrap_or_else(|| "".into()), - session, - ) - }); + let uiaa_request = json_body + .get("auth") + .and_then(|auth| auth.as_object()) + .and_then(|auth| auth.get("session")) + .and_then(|session| session.as_str()) + .and_then(|session| { + services().uiaa.get_uiaa_request( + &user_id, + &sender_device.clone().unwrap_or_else(|| "".into()), + session, + ) + }); - if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { - for (key, value) in initial_request { - json_body.entry(key).or_insert(value); - } - } + if let Some(CanonicalJsonValue::Object(initial_request)) = uiaa_request { + for (key, value) in initial_request { + json_body.entry(key).or_insert(value); + } + } - let mut buf = BytesMut::new().writer(); - serde_json::to_writer(&mut 
buf, json_body).expect("value serialization can't fail"); - body = buf.into_inner().freeze(); - } + let mut buf = BytesMut::new().writer(); + serde_json::to_writer(&mut buf, json_body).expect("value serialization can't fail"); + body = buf.into_inner().freeze(); + } - let http_request = http_request.body(&*body).unwrap(); + let http_request = http_request.body(&*body).unwrap(); - debug!("{:?}", http_request); + debug!("{:?}", http_request); - let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { - warn!("try_from_http_request failed: {:?}", e); - debug!("JSON body: {:?}", json_body); - Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") - })?; + let body = T::try_from_http_request(http_request, &path_params).map_err(|e| { + warn!("try_from_http_request failed: {:?}", e); + debug!("JSON body: {:?}", json_body); + Error::BadRequest(ErrorKind::BadJson, "Failed to deserialize request.") + })?; - Ok(Ruma { - body, - sender_user, - sender_device, - sender_servername, - from_appservice, - json_body, - }) - } + Ok(Ruma { + body, + sender_user, + sender_device, + sender_servername, + from_appservice, + json_body, + }) + } } struct XMatrix { - origin: OwnedServerName, - destination: Option, - key: String, // KeyName? - sig: String, + origin: OwnedServerName, + destination: Option, + key: String, // KeyName? + sig: String, } impl Credentials for XMatrix { - const SCHEME: &'static str = "X-Matrix"; + const SCHEME: &'static str = "X-Matrix"; - fn decode(value: &http::HeaderValue) -> Option { - debug_assert!( - value.as_bytes().starts_with(b"X-Matrix "), - "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", - ); + fn decode(value: &http::HeaderValue) -> Option { + debug_assert!( + value.as_bytes().starts_with(b"X-Matrix "), + "HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}", + ); - let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]) - .ok()? 
- .trim_start(); + let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start(); - let mut origin = None; - let mut destination = None; - let mut key = None; - let mut sig = None; + let mut origin = None; + let mut destination = None; + let mut key = None; + let mut sig = None; - for entry in parameters.split_terminator(',') { - let (name, value) = entry.split_once('=')?; + for entry in parameters.split_terminator(',') { + let (name, value) = entry.split_once('=')?; - // It's not at all clear why some fields are quoted and others not in the spec, - // let's simply accept either form for every field. - let value = value - .strip_prefix('"') - .and_then(|rest| rest.strip_suffix('"')) - .unwrap_or(value); + // It's not at all clear why some fields are quoted and others not in the spec, + // let's simply accept either form for every field. + let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value); - // FIXME: Catch multiple fields of the same name - match name { - "origin" => origin = Some(value.try_into().ok()?), - "key" => key = Some(value.to_owned()), - "sig" => sig = Some(value.to_owned()), - "destination" => destination = Some(value.to_owned()), - _ => debug!( - "Unexpected field `{}` in X-Matrix Authorization header", - name - ), - } - } + // FIXME: Catch multiple fields of the same name + match name { + "origin" => origin = Some(value.try_into().ok()?), + "key" => key = Some(value.to_owned()), + "sig" => sig = Some(value.to_owned()), + "destination" => destination = Some(value.to_owned()), + _ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name), + } + } - Some(Self { - origin: origin?, - key: key?, - sig: sig?, - destination, - }) - } + Some(Self { + origin: origin?, + key: key?, + sig: sig?, + destination, + }) + } - fn encode(&self) -> http::HeaderValue { - todo!() - } + fn encode(&self) -> http::HeaderValue { todo!() } } impl IntoResponse for RumaResponse { - fn 
into_response(self) -> Response { - match self.0.try_into_http_response::() { - Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), - Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), - } - } + fn into_response(self) -> Response { + match self.0.try_into_http_response::() { + Ok(res) => res.map(BytesMut::freeze).map(Full::new).into_response(), + Err(_) => StatusCode::INTERNAL_SERVER_ERROR.into_response(), + } + } } // copied from hyper under the following license: @@ -443,32 +376,32 @@ impl IntoResponse for RumaResponse { // THE SOFTWARE. pub(crate) async fn to_bytes(body: T) -> Result where - T: HttpBody, + T: HttpBody, { - futures_util::pin_mut!(body); + futures_util::pin_mut!(body); - // If there's only 1 chunk, we can just return Buf::to_bytes() - let mut first = if let Some(buf) = body.data().await { - buf? - } else { - return Ok(Bytes::new()); - }; + // If there's only 1 chunk, we can just return Buf::to_bytes() + let mut first = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(Bytes::new()); + }; - let second = if let Some(buf) = body.data().await { - buf? - } else { - return Ok(first.copy_to_bytes(first.remaining())); - }; + let second = if let Some(buf) = body.data().await { + buf? + } else { + return Ok(first.copy_to_bytes(first.remaining())); + }; - // With more than 1 buf, we gotta flatten into a Vec first. - let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; - let mut vec = Vec::with_capacity(cap); - vec.put(first); - vec.put(second); + // With more than 1 buf, we gotta flatten into a Vec first. 
+ let cap = first.remaining() + second.remaining() + body.size_hint().lower() as usize; + let mut vec = Vec::with_capacity(cap); + vec.put(first); + vec.put(second); - while let Some(buf) = body.data().await { - vec.put(buf?); - } + while let Some(buf) = body.data().await { + vec.put(buf?); + } - Ok(vec.into()) + Ok(vec.into()) } diff --git a/src/api/ruma_wrapper/mod.rs b/src/api/ruma_wrapper/mod.rs index ac4c825a..94c7dd8e 100644 --- a/src/api/ruma_wrapper/mod.rs +++ b/src/api/ruma_wrapper/mod.rs @@ -1,43 +1,36 @@ -use crate::Error; -use ruma::{ - api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, - OwnedUserId, -}; use std::ops::Deref; +use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId}; + +use crate::Error; + #[cfg(feature = "conduit_bin")] mod axum; /// Extractor for Ruma request structs pub struct Ruma { - pub body: T, - pub sender_user: Option, - pub sender_device: Option, - pub sender_servername: Option, - // This is None when body is not a valid string - pub json_body: Option, - pub from_appservice: bool, + pub body: T, + pub sender_user: Option, + pub sender_device: Option, + pub sender_servername: Option, + // This is None when body is not a valid string + pub json_body: Option, + pub from_appservice: bool, } impl Deref for Ruma { - type Target = T; + type Target = T; - fn deref(&self) -> &Self::Target { - &self.body - } + fn deref(&self) -> &Self::Target { &self.body } } #[derive(Clone)] pub struct RumaResponse(pub T); impl From for RumaResponse { - fn from(t: T) -> Self { - Self(t) - } + fn from(t: T) -> Self { Self(t) } } impl From for RumaResponse { - fn from(t: Error) -> Self { - t.to_response() - } + fn from(t: Error) -> Self { t.to_response() } } diff --git a/src/api/server_server.rs b/src/api/server_server.rs index e05bffb5..d33d3f2b 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1,65 +1,63 @@ #![allow(deprecated)] // Conduit 
implements the older APIs -use crate::{ - api::client_server::{self, claim_keys_helper, get_keys_helper}, - service::pdu::{gen_event_id_canonical_json, PduBuilder}, - services, utils, Error, PduEvent, Result, Ruma, +use std::{ + collections::BTreeMap, + fmt::Debug, + mem, + net::{IpAddr, SocketAddr}, + sync::{Arc, RwLock}, + time::{Duration, Instant, SystemTime}, }; + use axum::{response::IntoResponse, Json}; use futures_util::future::TryFutureExt; use get_profile_information::v1::ProfileField; use http::header::{HeaderValue, AUTHORIZATION}; - use ipaddress::IPAddress; use ruma::{ - api::{ - client::error::{Error as RumaError, ErrorKind}, - federation::{ - authorization::get_event_authorization, - backfill::get_backfill, - device::get_devices::{self, v1::UserDevice}, - directory::{get_public_rooms, get_public_rooms_filtered}, - discovery::{get_server_keys, get_server_version, ServerSigningKeys, VerifyKey}, - event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, - keys::{claim_keys, get_keys}, - membership::{create_invite, create_join_event, prepare_join_event}, - query::{get_profile_information, get_room_information}, - transactions::{ - edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, - send_transaction_message, - }, - }, - EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, - SendAccessToken, - }, - directory::{Filter, RoomNetwork}, - events::{ - receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, - room::{ - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - }, - StateEventType, TimelineEventType, - }, - serde::{Base64, JsonObject, Raw}, - to_device::DeviceIdOrAllDevices, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, - ServerName, + api::{ + client::error::{Error as RumaError, 
ErrorKind}, + federation::{ + authorization::get_event_authorization, + backfill::get_backfill, + device::get_devices::{self, v1::UserDevice}, + directory::{get_public_rooms, get_public_rooms_filtered}, + discovery::{get_server_keys, get_server_version, ServerSigningKeys, VerifyKey}, + event::{get_event, get_missing_events, get_room_state, get_room_state_ids}, + keys::{claim_keys, get_keys}, + membership::{create_invite, create_join_event, prepare_join_event}, + query::{get_profile_information, get_room_information}, + transactions::{ + edu::{DeviceListUpdateContent, DirectDeviceContent, Edu, SigningKeyUpdateContent}, + send_transaction_message, + }, + }, + EndpointError, IncomingResponse, MatrixVersion, OutgoingRequest, OutgoingResponse, SendAccessToken, + }, + directory::{Filter, RoomNetwork}, + events::{ + receipt::{ReceiptEvent, ReceiptEventContent, ReceiptType}, + room::{ + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + }, + StateEventType, TimelineEventType, + }, + serde::{Base64, JsonObject, Raw}, + to_device::DeviceIdOrAllDevices, + uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, + OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, RoomId, ServerName, }; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; -use std::{ - collections::BTreeMap, - fmt::Debug, - mem, - net::{IpAddr, SocketAddr}, - sync::{Arc, RwLock}, - time::{Duration, Instant, SystemTime}, -}; +use tracing::{debug, error, info, warn}; use trust_dns_resolver::{error::ResolveError, lookup::SrvLookup}; -use tracing::{debug, error, info, warn}; +use crate::{ + api::client_server::{self, claim_keys_helper, get_keys_helper}, + service::pdu::{gen_event_id_canonical_json, PduBuilder}, + services, utils, Error, PduEvent, Result, Ruma, +}; /// Wraps either an literal IP address plus port, or a hostname plus complement /// (colon-plus-port if it was 
specified). @@ -81,1247 +79,1033 @@ use tracing::{debug, error, info, warn}; /// ``` #[derive(Clone, Debug, PartialEq, Eq)] pub enum FedDest { - Literal(SocketAddr), - Named(String, String), + Literal(SocketAddr), + Named(String, String), } impl FedDest { - fn into_https_string(self) -> String { - match self { - Self::Literal(addr) => format!("https://{addr}"), - Self::Named(host, port) => format!("https://{host}{port}"), - } - } + fn into_https_string(self) -> String { + match self { + Self::Literal(addr) => format!("https://{addr}"), + Self::Named(host, port) => format!("https://{host}{port}"), + } + } - fn into_uri_string(self) -> String { - match self { - Self::Literal(addr) => addr.to_string(), - Self::Named(host, port) => host + &port, - } - } + fn into_uri_string(self) -> String { + match self { + Self::Literal(addr) => addr.to_string(), + Self::Named(host, port) => host + &port, + } + } - fn hostname(&self) -> String { - match &self { - Self::Literal(addr) => addr.ip().to_string(), - Self::Named(host, _) => host.clone(), - } - } + fn hostname(&self) -> String { + match &self { + Self::Literal(addr) => addr.ip().to_string(), + Self::Named(host, _) => host.clone(), + } + } - fn port(&self) -> Option { - match &self { - Self::Literal(addr) => Some(addr.port()), - Self::Named(_, port) => port[1..].parse().ok(), - } - } + fn port(&self) -> Option { + match &self { + Self::Literal(addr) => Some(addr.port()), + Self::Named(_, port) => port[1..].parse().ok(), + } + } } -pub(crate) async fn send_request( - destination: &ServerName, - request: T, -) -> Result +pub(crate) async fn send_request(destination: &ServerName, request: T) -> Result where - T: OutgoingRequest + Debug, + T: OutgoingRequest + Debug, { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if destination == 
services().globals.server_name() { - return Err(Error::bad_config( - "Won't send federation request to ourselves", - )); - } + if destination == services().globals.server_name() { + return Err(Error::bad_config("Won't send federation request to ourselves")); + } - if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { - info!( - "Destination {} is an IP literal, checking against IP range denylist.", - destination - ); - let ip = IPAddress::parse(destination.host()).map_err(|e| { - warn!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; + if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { + info!( + "Destination {} is an IP literal, checking against IP range denylist.", + destination + ); + let ip = IPAddress::parse(destination.host()).map_err(|e| { + warn!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } - debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); + debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse( - "Not allowed to send requests to this IP", - )); - } - } + for cidr in cidr_ranges { + if cidr.includes(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } - info!("IP literal {} is allowed.", destination); - } + info!("IP literal {} is allowed.", destination); + } - debug!("Preparing to send request to 
{destination}"); + debug!("Preparing to send request to {destination}"); - let mut write_destination_to_cache = false; + let mut write_destination_to_cache = false; - let cached_result = services() - .globals - .actual_destination_cache - .read() - .unwrap() - .get(destination) - .cloned(); + let cached_result = services().globals.actual_destination_cache.read().unwrap().get(destination).cloned(); - let (actual_destination, host) = if let Some(result) = cached_result { - result - } else { - write_destination_to_cache = true; + let (actual_destination, host) = if let Some(result) = cached_result { + result + } else { + write_destination_to_cache = true; - let result = find_actual_destination(destination).await; + let result = find_actual_destination(destination).await; - (result.0, result.1.into_uri_string()) - }; + (result.0, result.1.into_uri_string()) + }; - let actual_destination_str = actual_destination.clone().into_https_string(); + let actual_destination_str = actual_destination.clone().into_https_string(); - let mut http_request = request - .try_into_http_request::>( - &actual_destination_str, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_5], - ) - .map_err(|e| { - warn!( - "Failed to find destination {}: {}", - actual_destination_str, e - ); - Error::BadServerResponse("Invalid destination") - })?; + let mut http_request = request + .try_into_http_request::>( + &actual_destination_str, + SendAccessToken::IfRequired(""), + &[MatrixVersion::V1_5], + ) + .map_err(|e| { + warn!("Failed to find destination {}: {}", actual_destination_str, e); + Error::BadServerResponse("Invalid destination") + })?; - let mut request_map = serde_json::Map::new(); + let mut request_map = serde_json::Map::new(); - if !http_request.body().is_empty() { - request_map.insert( - "content".to_owned(), - serde_json::from_slice(http_request.body()) - .expect("body is valid json, we just created it"), - ); - }; + if !http_request.body().is_empty() { + request_map.insert( + 
"content".to_owned(), + serde_json::from_slice(http_request.body()).expect("body is valid json, we just created it"), + ); + }; - request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); - request_map.insert( - "uri".to_owned(), - http_request - .uri() - .path_and_query() - .expect("all requests have a path") - .to_string() - .into(), - ); - request_map.insert( - "origin".to_owned(), - services().globals.server_name().as_str().into(), - ); - request_map.insert("destination".to_owned(), destination.as_str().into()); + request_map.insert("method".to_owned(), T::METADATA.method.to_string().into()); + request_map.insert( + "uri".to_owned(), + http_request.uri().path_and_query().expect("all requests have a path").to_string().into(), + ); + request_map.insert("origin".to_owned(), services().globals.server_name().as_str().into()); + request_map.insert("destination".to_owned(), destination.as_str().into()); - let mut request_json = - serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); + let mut request_json = serde_json::from_value(request_map.into()).expect("valid JSON is valid BTreeMap"); - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut request_json, - ) - .expect("our request json is what ruma expects"); + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut request_json, + ) + .expect("our request json is what ruma expects"); - let request_json: serde_json::Map = - serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); + let request_json: serde_json::Map = + serde_json::from_slice(&serde_json::to_vec(&request_json).unwrap()).unwrap(); - let signatures = request_json["signatures"] - .as_object() - .unwrap() - .values() - .map(|v| { - v.as_object() - .unwrap() - .iter() - .map(|(k, v)| (k, v.as_str().unwrap())) - }); + let signatures = request_json["signatures"] + 
.as_object() + .unwrap() + .values() + .map(|v| v.as_object().unwrap().iter().map(|(k, v)| (k, v.as_str().unwrap()))); - for signature_server in signatures { - for s in signature_server { - http_request.headers_mut().insert( - AUTHORIZATION, - HeaderValue::from_str(&format!( - "X-Matrix origin={},key=\"{}\",sig=\"{}\"", - services().globals.server_name(), - s.0, - s.1 - )) - .unwrap(), - ); - } - } + for signature_server in signatures { + for s in signature_server { + http_request.headers_mut().insert( + AUTHORIZATION, + HeaderValue::from_str(&format!( + "X-Matrix origin={},key=\"{}\",sig=\"{}\"", + services().globals.server_name(), + s.0, + s.1 + )) + .unwrap(), + ); + } + } - let reqwest_request = reqwest::Request::try_from(http_request)?; + let reqwest_request = reqwest::Request::try_from(http_request)?; - let url = reqwest_request.url().clone(); + let url = reqwest_request.url().clone(); - debug!("Sending request to {destination} at {url}"); - let response = services() - .globals - .federation_client() - .execute(reqwest_request) - .await; - debug!("Received response from {destination} at {url}"); + debug!("Sending request to {destination} at {url}"); + let response = services().globals.federation_client().execute(reqwest_request).await; + debug!("Received response from {destination} at {url}"); - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + mem::swap( + response.headers_mut(), + 
http_response_builder.headers_mut().expect("http::response::Builder is usable"), + ); - debug!("Getting response bytes from {destination}"); - let body = response.bytes().await.unwrap_or_else(|e| { - info!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout - debug!("Got response bytes from {destination}"); + debug!("Getting response bytes from {destination}"); + let body = response.bytes().await.unwrap_or_else(|e| { + info!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout + debug!("Got response bytes from {destination}"); - if !status.is_success() { - debug!( - "Response not successful\n{} {}: {}", - url, - status, - String::from_utf8_lossy(&body) - .lines() - .collect::>() - .join(" ") - ); - } + if !status.is_success() { + debug!( + "Response not successful\n{} {}: {}", + url, + status, + String::from_utf8_lossy(&body).lines().collect::>().join(" ") + ); + } - let http_response = http_response_builder - .body(body) - .expect("reqwest body is valid http body"); + let http_response = http_response_builder.body(body).expect("reqwest body is valid http body"); - if status.is_success() { - debug!("Parsing response bytes from {destination}"); - let response = T::IncomingResponse::try_from_http_response(http_response); - if response.is_ok() && write_destination_to_cache { - services() - .globals - .actual_destination_cache - .write() - .unwrap() - .insert( - OwnedServerName::from(destination), - (actual_destination, host), - ); - } + if status.is_success() { + debug!("Parsing response bytes from {destination}"); + let response = T::IncomingResponse::try_from_http_response(http_response); + if response.is_ok() && write_destination_to_cache { + services() + .globals + .actual_destination_cache + .write() + .unwrap() + .insert(OwnedServerName::from(destination), (actual_destination, host)); + } - response.map_err(|e| { - warn!( - "Invalid 200 response from {} on: {} {}", - &destination, url, e - ); - 
Error::BadServerResponse("Server returned bad 200 response.") - }) - } else { - debug!("Returning error from {destination}"); + response.map_err(|e| { + warn!("Invalid 200 response from {} on: {} {}", &destination, url, e); + Error::BadServerResponse("Server returned bad 200 response.") + }) + } else { + debug!("Returning error from {destination}"); - // remove potentially dead destinations from our cache that may be from modified well-knowns - if !write_destination_to_cache { - info!("Evicting {destination} from our true destination cache due to failed request."); - services() - .globals - .actual_destination_cache - .write() - .unwrap() - .remove(destination); - } + // remove potentially dead destinations from our cache that may be from modified + // well-knowns + if !write_destination_to_cache { + info!("Evicting {destination} from our true destination cache due to failed request."); + services().globals.actual_destination_cache.write().unwrap().remove(destination); + } - Err(Error::FederationError( - destination.to_owned(), - RumaError::from_http_response(http_response), - )) - } - } - Err(e) => { - // we do not need to log that servers in a room are dead, this is normal in public rooms and just spams the logs. 
- match e.is_timeout() { - true => info!( - "Timed out sending request to {} at {}: {}", - destination, actual_destination_str, e - ), - false => match e.is_connect() { - true => info!( - "Failed to connect to {} at {}: {}", - destination, actual_destination_str, e - ), - false => match e.is_redirect() { - true => info!( - "Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", - destination, - actual_destination_str, - e, - e.url() - ), - false => warn!( - "Could not send request to {} at {}: {}", - destination, actual_destination_str, e - ), - }, - }, - } - Err(e.into()) - } - } + Err(Error::FederationError( + destination.to_owned(), + RumaError::from_http_response(http_response), + )) + } + }, + Err(e) => { + // we do not need to log that servers in a room are dead, this is normal in + // public rooms and just spams the logs. + match e.is_timeout() { + true => info!( + "Timed out sending request to {} at {}: {}", + destination, actual_destination_str, e + ), + false => match e.is_connect() { + true => info!("Failed to connect to {} at {}: {}", destination, actual_destination_str, e), + false => match e.is_redirect() { + true => info!( + "Redirect loop sending request to {} at {}: {}\nFinal URL: {:?}", + destination, + actual_destination_str, + e, + e.url() + ), + false => { + warn!("Could not send request to {} at {}: {}", destination, actual_destination_str, e) + }, + }, + }, + } + Err(e.into()) + }, + } } fn get_ip_with_port(destination_str: &str) -> Option { - if let Ok(destination) = destination_str.parse::() { - Some(FedDest::Literal(destination)) - } else if let Ok(ip_addr) = destination_str.parse::() { - Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) - } else { - None - } + if let Ok(destination) = destination_str.parse::() { + Some(FedDest::Literal(destination)) + } else if let Ok(ip_addr) = destination_str.parse::() { + Some(FedDest::Literal(SocketAddr::new(ip_addr, 8448))) + } else { + None + } } fn 
add_port_to_hostname(destination_str: &str) -> FedDest { - let (host, port) = match destination_str.find(':') { - None => (destination_str, ":8448"), - Some(pos) => destination_str.split_at(pos), - }; - FedDest::Named(host.to_owned(), port.to_owned()) + let (host, port) = match destination_str.find(':') { + None => (destination_str, ":8448"), + Some(pos) => destination_str.split_at(pos), + }; + FedDest::Named(host.to_owned(), port.to_owned()) } /// Returns: actual_destination, host header /// Implemented according to the specification at -/// Numbers in comments below refer to bullet points in linked section of specification +/// Numbers in comments below refer to bullet points in linked section of +/// specification async fn find_actual_destination(destination: &'_ ServerName) -> (FedDest, FedDest) { - debug!("Finding actual destination for {destination}"); - let destination_str = destination.as_str().to_owned(); - let mut hostname = destination_str.clone(); - let actual_destination = match get_ip_with_port(&destination_str) { - Some(host_port) => { - debug!("1: IP literal with provided or default port"); - host_port - } - None => { - if let Some(pos) = destination_str.find(':') { - debug!("2: Hostname with included port"); - let (host, port) = destination_str.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - debug!("Requesting well known for {destination}"); - match request_well_known(destination.as_str()).await { - Some(delegated_hostname) => { - debug!("3: A .well-known file is available"); - hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); - match get_ip_with_port(&delegated_hostname) { - Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file - None => { - if let Some(pos) = delegated_hostname.find(':') { - debug!("3.2: Hostname with port in .well-known file"); - let (host, port) = delegated_hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - 
debug!("Delegated hostname has no port in this branch"); - if let Some(hostname_override) = - query_srv_record(&delegated_hostname).await - { - debug!("3.3: SRV lookup successful"); - let force_port = hostname_override.port(); + debug!("Finding actual destination for {destination}"); + let destination_str = destination.as_str().to_owned(); + let mut hostname = destination_str.clone(); + let actual_destination = match get_ip_with_port(&destination_str) { + Some(host_port) => { + debug!("1: IP literal with provided or default port"); + host_port + }, + None => { + if let Some(pos) = destination_str.find(':') { + debug!("2: Hostname with included port"); + let (host, port) = destination_str.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + debug!("Requesting well known for {destination}"); + match request_well_known(destination.as_str()).await { + Some(delegated_hostname) => { + debug!("3: A .well-known file is available"); + hostname = add_port_to_hostname(&delegated_hostname).into_uri_string(); + match get_ip_with_port(&delegated_hostname) { + Some(host_and_port) => host_and_port, // 3.1: IP literal in .well-known file + None => { + if let Some(pos) = delegated_hostname.find(':') { + debug!("3.2: Hostname with port in .well-known file"); + let (host, port) = delegated_hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + debug!("Delegated hostname has no port in this branch"); + if let Some(hostname_override) = query_srv_record(&delegated_hostname).await { + debug!("3.3: SRV lookup successful"); + let force_port = hostname_override.port(); - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - delegated_hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - debug!( - "Using SRV record {}, but could not 
resolve to IP", - hostname_override.hostname() - ); - } + if let Ok(override_ip) = services() + .globals + .dns_resolver() + .lookup_ip(hostname_override.hostname()) + .await + { + services().globals.tls_name_override.write().unwrap().insert( + delegated_hostname.clone(), + (override_ip.iter().collect(), force_port.unwrap_or(8448)), + ); + } else { + debug!( + "Using SRV record {}, but could not resolve to IP", + hostname_override.hostname() + ); + } - if let Some(port) = force_port { - FedDest::Named(delegated_hostname, format!(":{port}")) - } else { - add_port_to_hostname(&delegated_hostname) - } - } else { - debug!("3.4: No SRV records, just use the hostname from .well-known"); - add_port_to_hostname(&delegated_hostname) - } - } - } - } - } - None => { - debug!("4: No .well-known or an error occured"); - match query_srv_record(&destination_str).await { - Some(hostname_override) => { - debug!("4: SRV record found"); - let force_port = hostname_override.port(); + if let Some(port) = force_port { + FedDest::Named(delegated_hostname, format!(":{port}")) + } else { + add_port_to_hostname(&delegated_hostname) + } + } else { + debug!("3.4: No SRV records, just use the hostname from .well-known"); + add_port_to_hostname(&delegated_hostname) + } + } + }, + } + }, + None => { + debug!("4: No .well-known or an error occured"); + match query_srv_record(&destination_str).await { + Some(hostname_override) => { + debug!("4: SRV record found"); + let force_port = hostname_override.port(); - if let Ok(override_ip) = services() - .globals - .dns_resolver() - .lookup_ip(hostname_override.hostname()) - .await - { - services() - .globals - .tls_name_override - .write() - .unwrap() - .insert( - hostname.clone(), - ( - override_ip.iter().collect(), - force_port.unwrap_or(8448), - ), - ); - } else { - debug!( - "Using SRV record {}, but could not resolve to IP", - hostname_override.hostname() - ); - } + if let Ok(override_ip) = + 
services().globals.dns_resolver().lookup_ip(hostname_override.hostname()).await + { + services().globals.tls_name_override.write().unwrap().insert( + hostname.clone(), + (override_ip.iter().collect(), force_port.unwrap_or(8448)), + ); + } else { + debug!( + "Using SRV record {}, but could not resolve to IP", + hostname_override.hostname() + ); + } - if let Some(port) = force_port { - FedDest::Named(hostname.clone(), format!(":{port}")) - } else { - add_port_to_hostname(&hostname) - } - } - None => { - debug!("5: No SRV record found"); - add_port_to_hostname(&destination_str) - } - } - } - } - } - } - }; - debug!("Actual destination: {actual_destination:?}"); + if let Some(port) = force_port { + FedDest::Named(hostname.clone(), format!(":{port}")) + } else { + add_port_to_hostname(&hostname) + } + }, + None => { + debug!("5: No SRV record found"); + add_port_to_hostname(&destination_str) + }, + } + }, + } + } + }, + }; + debug!("Actual destination: {actual_destination:?}"); - // Can't use get_ip_with_port here because we don't want to add a port - // to an IP address if it wasn't specified - let hostname = if let Ok(addr) = hostname.parse::() { - FedDest::Literal(addr) - } else if let Ok(addr) = hostname.parse::() { - FedDest::Named(addr.to_string(), ":8448".to_owned()) - } else if let Some(pos) = hostname.find(':') { - let (host, port) = hostname.split_at(pos); - FedDest::Named(host.to_owned(), port.to_owned()) - } else { - FedDest::Named(hostname, ":8448".to_owned()) - }; - (actual_destination, hostname) + // Can't use get_ip_with_port here because we don't want to add a port + // to an IP address if it wasn't specified + let hostname = if let Ok(addr) = hostname.parse::() { + FedDest::Literal(addr) + } else if let Ok(addr) = hostname.parse::() { + FedDest::Named(addr.to_string(), ":8448".to_owned()) + } else if let Some(pos) = hostname.find(':') { + let (host, port) = hostname.split_at(pos); + FedDest::Named(host.to_owned(), port.to_owned()) + } else { + 
FedDest::Named(hostname, ":8448".to_owned()) + }; + (actual_destination, hostname) } async fn query_srv_record(hostname: &'_ str) -> Option { - fn handle_successful_srv(srv: SrvLookup) -> Option { - srv.iter().next().map(|result| { - FedDest::Named( - result.target().to_string().trim_end_matches('.').to_owned(), - format!(":{}", result.port()), - ) - }) - } + fn handle_successful_srv(srv: SrvLookup) -> Option { + srv.iter().next().map(|result| { + FedDest::Named( + result.target().to_string().trim_end_matches('.').to_owned(), + format!(":{}", result.port()), + ) + }) + } - async fn lookup_srv(hostname: &str) -> Result { - debug!("querying SRV for {:?}", hostname); - let hostname = hostname.trim_end_matches('.'); - services() - .globals - .dns_resolver() - .srv_lookup(hostname.to_owned()) - .await - } + async fn lookup_srv(hostname: &str) -> Result { + debug!("querying SRV for {:?}", hostname); + let hostname = hostname.trim_end_matches('.'); + services().globals.dns_resolver().srv_lookup(hostname.to_owned()).await + } - let first_hostname = format!("_matrix-fed._tcp.{hostname}."); - let second_hostname = format!("_matrix._tcp.{hostname}."); + let first_hostname = format!("_matrix-fed._tcp.{hostname}."); + let second_hostname = format!("_matrix._tcp.{hostname}."); - lookup_srv(&first_hostname) - .or_else(|_| { - info!( - "Querying deprecated _matrix SRV record for host {:?}", - hostname - ); - lookup_srv(&second_hostname) - }) - .and_then(|srv_lookup| async { Ok(handle_successful_srv(srv_lookup)) }) - .await - .ok() - .flatten() + lookup_srv(&first_hostname) + .or_else(|_| { + info!("Querying deprecated _matrix SRV record for host {:?}", hostname); + lookup_srv(&second_hostname) + }) + .and_then(|srv_lookup| async { Ok(handle_successful_srv(srv_lookup)) }) + .await + .ok() + .flatten() } async fn request_well_known(destination: &str) -> Option { - let response = services() - .globals - .default_client() - 
.get(&format!("https://{destination}/.well-known/matrix/server")) - .send() - .await; - debug!("Got well known response"); - debug!("Well known response: {:?}", response); + let response = services() + .globals + .default_client() + .get(&format!("https://{destination}/.well-known/matrix/server")) + .send() + .await; + debug!("Got well known response"); + debug!("Well known response: {:?}", response); - if let Err(e) = &response { - debug!("Well known error: {e:?}"); - return None; - } + if let Err(e) = &response { + debug!("Well known error: {e:?}"); + return None; + } - let text = response.ok()?.text().await; + let text = response.ok()?.text().await; - debug!("Got well known response text"); - debug!("Well known response text: {:?}", text); + debug!("Got well known response text"); + debug!("Well known response text: {:?}", text); - if text.as_ref().ok()?.len() > 10000 { - info!("Well known response for destination '{destination}' exceeded past 10000 characters, assuming no well-known."); - return None; - } + if text.as_ref().ok()?.len() > 10000 { + info!( + "Well known response for destination '{destination}' exceeded past 10000 characters, assuming no \ + well-known." + ); + return None; + } - let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; - debug!("serde_json body of well known text: {}", body); + let body: serde_json::Value = serde_json::from_str(&text.ok()?).ok()?; + debug!("serde_json body of well known text: {}", body); - Some(body.get("m.server")?.as_str()?.to_owned()) + Some(body.get("m.server")?.as_str()?.to_owned()) } /// # `GET /_matrix/federation/v1/version` /// /// Get version information on this server. 
pub async fn get_server_version_route( - _body: Ruma, + _body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - Ok(get_server_version::v1::Response { - server: Some(get_server_version::v1::Server { - name: Some("Conduwuit".to_owned()), - version: Some(env!("CARGO_PKG_VERSION").to_owned()), - }), - }) + Ok(get_server_version::v1::Response { + server: Some(get_server_version::v1::Server { + name: Some("Conduwuit".to_owned()), + version: Some(env!("CARGO_PKG_VERSION").to_owned()), + }), + }) } /// # `GET /_matrix/key/v2/server` /// /// Gets the public signing keys of this server. /// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// - Matrix does not support invalidating public keys, so the key returned by +/// this will be valid /// forever. -// Response type for this endpoint is Json because we need to calculate a signature for the response +// Response type for this endpoint is Json because we need to calculate a +// signature for the response pub async fn get_server_keys_route() -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let mut verify_keys: BTreeMap = BTreeMap::new(); - verify_keys.insert( - format!("ed25519:{}", services().globals.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(services().globals.keypair().public_key().to_vec()), - }, - ); - let mut response = serde_json::from_slice( - get_server_keys::v2::Response { - server_key: Raw::new(&ServerSigningKeys { - server_name: services().globals.server_name().to_owned(), - verify_keys, - old_verify_keys: 
BTreeMap::new(), - signatures: BTreeMap::new(), - valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() + Duration::from_secs(86400 * 7), - ) - .expect("time is valid"), - }) - .expect("static conversion, no errors"), - } - .try_into_http_response::>() - .unwrap() - .body(), - ) - .unwrap(); + let mut verify_keys: BTreeMap = BTreeMap::new(); + verify_keys.insert( + format!("ed25519:{}", services().globals.keypair().version()) + .try_into() + .expect("found invalid server signing keys in DB"), + VerifyKey { + key: Base64::new(services().globals.keypair().public_key().to_vec()), + }, + ); + let mut response = serde_json::from_slice( + get_server_keys::v2::Response { + server_key: Raw::new(&ServerSigningKeys { + server_name: services().globals.server_name().to_owned(), + verify_keys, + old_verify_keys: BTreeMap::new(), + signatures: BTreeMap::new(), + valid_until_ts: MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now() + Duration::from_secs(86400 * 7), + ) + .expect("time is valid"), + }) + .expect("static conversion, no errors"), + } + .try_into_http_response::>() + .unwrap() + .body(), + ) + .unwrap(); - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut response, - ) - .unwrap(); + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut response, + ) + .unwrap(); - Ok(Json(response)) + Ok(Json(response)) } /// # `GET /_matrix/key/v2/server/{keyId}` /// /// Gets the public signing keys of this server. /// -/// - Matrix does not support invalidating public keys, so the key returned by this will be valid +/// - Matrix does not support invalidating public keys, so the key returned by +/// this will be valid /// forever. 
-pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { - get_server_keys_route().await -} +pub async fn get_server_keys_deprecated_route() -> impl IntoResponse { get_server_keys_route().await } /// # `POST /_matrix/federation/v1/publicRooms` /// /// Lists the public rooms on this server. pub async fn get_public_rooms_filtered_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if !services() - .globals - .allow_public_room_directory_over_federation() - { - return Err(Error::bad_config("Room directory is not public.")); - } + if !services().globals.allow_public_room_directory_over_federation() { + return Err(Error::bad_config("Room directory is not public.")); + } - let response = client_server::get_public_rooms_filtered_helper( - None, - body.limit, - body.since.as_deref(), - &body.filter, - &body.room_network, - ) - .await?; + let response = client_server::get_public_rooms_filtered_helper( + None, + body.limit, + body.since.as_deref(), + &body.filter, + &body.room_network, + ) + .await?; - Ok(get_public_rooms_filtered::v1::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }) + Ok(get_public_rooms_filtered::v1::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }) } /// # `GET /_matrix/federation/v1/publicRooms` /// /// Lists the public rooms on this server. 
pub async fn get_public_rooms_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if !services() - .globals - .allow_public_room_directory_over_federation() - { - return Err(Error::bad_config("Room directory is not public.")); - } + if !services().globals.allow_public_room_directory_over_federation() { + return Err(Error::bad_config("Room directory is not public.")); + } - let response = client_server::get_public_rooms_filtered_helper( - None, - body.limit, - body.since.as_deref(), - &Filter::default(), - &RoomNetwork::Matrix, - ) - .await?; + let response = client_server::get_public_rooms_filtered_helper( + None, + body.limit, + body.since.as_deref(), + &Filter::default(), + &RoomNetwork::Matrix, + ) + .await?; - Ok(get_public_rooms::v1::Response { - chunk: response.chunk, - prev_batch: response.prev_batch, - next_batch: response.next_batch, - total_room_count_estimate: response.total_room_count_estimate, - }) + Ok(get_public_rooms::v1::Response { + chunk: response.chunk, + prev_batch: response.prev_batch, + next_batch: response.next_batch, + total_room_count_estimate: response.total_room_count_estimate, + }) } -pub fn parse_incoming_pdu( - pdu: &RawJsonValue, -) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; +pub fn parse_incoming_pdu(pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; - let 
room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid room id in pdu", - ))?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; - let room_version_id = services().rooms.state.get_room_version(&room_id)?; + let room_version_id = services().rooms.state.get_room_version(&room_id)?; - let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; - Ok((event_id, value, room_id)) + let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }, + }; + Ok((event_id, value, room_id)) } /// # `PUT /_matrix/federation/v1/send/{txnId}` /// /// Push EDUs and PDUs to this server. 
pub async fn send_transaction_message_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - let mut resolved_map = BTreeMap::new(); + let mut resolved_map = BTreeMap::new(); - let pub_key_map = RwLock::new(BTreeMap::new()); + let pub_key_map = RwLock::new(BTreeMap::new()); - // This is all the auth_events that have been recursively fetched so they don't have to be - // deserialized over and over again. - // TODO: make this persist across requests but not in a DB Tree (in globals?) - // TODO: This could potentially also be some sort of trie (suffix tree) like structure so - // that once an auth event is known it would know (using indexes maybe) all of the auth - // events that it references. - // let mut auth_cache = EventMap::new(); + // This is all the auth_events that have been recursively fetched so they don't + // have to be deserialized over and over again. + // TODO: make this persist across requests but not in a DB Tree (in globals?) + // TODO: This could potentially also be some sort of trie (suffix tree) like + // structure so that once an auth event is known it would know (using indexes + // maybe) all of the auth events that it references. 
+ // let mut auth_cache = EventMap::new(); - let mut parsed_pdus = vec![]; - for pdu in &body.pdus { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - let room_id: OwnedRoomId = value - .get("room_id") - .and_then(|id| RoomId::parse(id.as_str()?).ok()) - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid room id in pdu", - ))?; + let mut parsed_pdus = vec![]; + for pdu in &body.pdus { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + let room_id: OwnedRoomId = value + .get("room_id") + .and_then(|id| RoomId::parse(id.as_str()?).ok()) + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid room id in pdu"))?; - if services().rooms.state.get_room_version(&room_id).is_err() { - debug!("Server is not in room {room_id}"); - continue; - } + if services().rooms.state.get_room_version(&room_id).is_err() { + debug!("Server is not in room {room_id}"); + continue; + } - let r = parse_incoming_pdu(pdu); - let (event_id, value, room_id) = match r { - Ok(t) => t, - Err(e) => { - warn!("Could not parse PDU: {e}"); - warn!("Full PDU: {:?}", &pdu); - continue; - } - }; - parsed_pdus.push((event_id, value, room_id)); - // We do not add the event_id field to the pdu here because of signature and hashes checks - } + let r = parse_incoming_pdu(pdu); + let (event_id, value, room_id) = match r { + Ok(t) => t, + Err(e) => { + warn!("Could not parse PDU: {e}"); + warn!("Full PDU: {:?}", &pdu); + continue; + }, + }; + parsed_pdus.push((event_id, value, room_id)); + // We do not add the event_id field to the pdu here because of signature + // and hashes checks + } - // We go through all the signatures we see on the PDUs and fetch the corresponding - // 
signing keys - services() - .rooms - .event_handler - .fetch_required_signing_keys( - parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), - &pub_key_map, - ) - .await - .unwrap_or_else(|e| { - warn!( - "Could not fetch all signatures for PDUs from {}: {:?}", - sender_servername, e - ); - }); + // We go through all the signatures we see on the PDUs and fetch the + // corresponding signing keys + services() + .rooms + .event_handler + .fetch_required_signing_keys(parsed_pdus.iter().map(|(_event_id, event, _room_id)| event), &pub_key_map) + .await + .unwrap_or_else(|e| { + warn!("Could not fetch all signatures for PDUs from {}: {:?}", sender_servername, e); + }); - for (event_id, value, room_id) in parsed_pdus { - let mutex = Arc::clone( - services() - .globals - .roomid_mutex_federation - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - let start_time = Instant::now(); - resolved_map.insert( - event_id.clone(), - services() - .rooms - .event_handler - .handle_incoming_pdu( - sender_servername, - &event_id, - &room_id, - value, - true, - &pub_key_map, - ) - .await - .map(|_| ()), - ); - drop(mutex_lock); + for (event_id, value, room_id) in parsed_pdus { + let mutex = + Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default()); + let mutex_lock = mutex.lock().await; + let start_time = Instant::now(); + resolved_map.insert( + event_id.clone(), + services() + .rooms + .event_handler + .handle_incoming_pdu(sender_servername, &event_id, &room_id, value, true, &pub_key_map) + .await + .map(|_| ()), + ); + drop(mutex_lock); - let elapsed = start_time.elapsed(); - debug!( - "Handling transaction of event {} took {}m{}s", - event_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } + let elapsed = start_time.elapsed(); + debug!( + "Handling transaction of event {} took {}m{}s", + event_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + 
} - for pdu in &resolved_map { - if let Err(e) = pdu.1 { - if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { - warn!("Incoming PDU failed {:?}", pdu); - } - } - } + for pdu in &resolved_map { + if let Err(e) = pdu.1 { + if matches!(e, Error::BadRequest(ErrorKind::NotFound, _)) { + warn!("Incoming PDU failed {:?}", pdu); + } + } + } - for edu in body - .edus - .iter() - .filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) - { - match edu { - Edu::Presence(presence) => { - if !services().globals.allow_incoming_presence() { - continue; - } + for edu in body.edus.iter().filter_map(|edu| serde_json::from_str::(edu.json().get()).ok()) { + match edu { + Edu::Presence(presence) => { + if !services().globals.allow_incoming_presence() { + continue; + } - for update in presence.push { - for room_id in services().rooms.state_cache.rooms_joined(&update.user_id) { - services().rooms.edus.presence.set_presence( - &room_id?, - &update.user_id, - update.presence.clone(), - Some(update.currently_active), - Some(update.last_active_ago), - update.status_msg.clone(), - )?; - } - } - } - Edu::Receipt(receipt) => { - for (room_id, room_updates) in receipt.receipts { - for (user_id, user_updates) in room_updates.read { - if let Some((event_id, _)) = user_updates - .event_ids - .iter() - .filter_map(|id| { - services() - .rooms - .timeline - .get_pdu_count(id) - .ok() - .flatten() - .map(|r| (id, r)) - }) - .max_by_key(|(_, count)| *count) - { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert(user_id.clone(), user_updates.data); + for update in presence.push { + for room_id in services().rooms.state_cache.rooms_joined(&update.user_id) { + services().rooms.edus.presence.set_presence( + &room_id?, + &update.user_id, + update.presence.clone(), + Some(update.currently_active), + Some(update.last_active_ago), + update.status_msg.clone(), + )?; + } + } + }, + Edu::Receipt(receipt) => { + for (room_id, room_updates) in receipt.receipts { + for (user_id, 
user_updates) in room_updates.read { + if let Some((event_id, _)) = user_updates + .event_ids + .iter() + .filter_map(|id| { + services().rooms.timeline.get_pdu_count(id).ok().flatten().map(|r| (id, r)) + }) + .max_by_key(|(_, count)| *count) + { + let mut user_receipts = BTreeMap::new(); + user_receipts.insert(user_id.clone(), user_updates.data); - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); + let mut receipts = BTreeMap::new(); + receipts.insert(ReceiptType::Read, user_receipts); - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event_id.to_owned(), receipts); + let mut receipt_content = BTreeMap::new(); + receipt_content.insert(event_id.to_owned(), receipts); - let event = ReceiptEvent { - content: ReceiptEventContent(receipt_content), - room_id: room_id.clone(), - }; - services() - .rooms - .edus - .read_receipt - .readreceipt_update(&user_id, &room_id, event)?; - } else { - // TODO fetch missing events - debug!("No known event ids in read receipt: {:?}", user_updates); - } - } - } - } - Edu::Typing(typing) => { - if services() - .rooms - .state_cache - .is_joined(&typing.user_id, &typing.room_id)? - { - if typing.typing { - services().rooms.edus.typing.typing_add( - &typing.user_id, - &typing.room_id, - 3000 + utils::millis_since_unix_epoch(), - )?; - } else { - services() - .rooms - .edus - .typing - .typing_remove(&typing.user_id, &typing.room_id)?; - } - } - } - Edu::DeviceListUpdate(DeviceListUpdateContent { user_id, .. }) => { - services().users.mark_device_key_update(&user_id)?; - } - Edu::DirectToDevice(DirectDeviceContent { - sender, - ev_type, - message_id, - messages, - }) => { - // Check if this is a new transaction id - if services() - .transaction_ids - .existing_txnid(&sender, None, &message_id)? 
- .is_some() - { - continue; - } + let event = ReceiptEvent { + content: ReceiptEventContent(receipt_content), + room_id: room_id.clone(), + }; + services().rooms.edus.read_receipt.readreceipt_update(&user_id, &room_id, event)?; + } else { + // TODO fetch missing events + debug!("No known event ids in read receipt: {:?}", user_updates); + } + } + } + }, + Edu::Typing(typing) => { + if services().rooms.state_cache.is_joined(&typing.user_id, &typing.room_id)? { + if typing.typing { + services().rooms.edus.typing.typing_add( + &typing.user_id, + &typing.room_id, + 3000 + utils::millis_since_unix_epoch(), + )?; + } else { + services().rooms.edus.typing.typing_remove(&typing.user_id, &typing.room_id)?; + } + } + }, + Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id, + .. + }) => { + services().users.mark_device_key_update(&user_id)?; + }, + Edu::DirectToDevice(DirectDeviceContent { + sender, + ev_type, + message_id, + messages, + }) => { + // Check if this is a new transaction id + if services().transaction_ids.existing_txnid(&sender, None, &message_id)?.is_some() { + continue; + } - for (target_user_id, map) in &messages { - for (target_device_id_maybe, event) in map { - match target_device_id_maybe { - DeviceIdOrAllDevices::DeviceId(target_device_id) => { - services().users.add_to_device_event( - &sender, - target_user_id, - target_device_id, - &ev_type.to_string(), - event.deserialize_as().map_err(|e| { - warn!("To-Device event is invalid: {event:?} {e}"); - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )?; - } + for (target_user_id, map) in &messages { + for (target_device_id_maybe, event) in map { + match target_device_id_maybe { + DeviceIdOrAllDevices::DeviceId(target_device_id) => { + services().users.add_to_device_event( + &sender, + target_user_id, + target_device_id, + &ev_type.to_string(), + event.deserialize_as().map_err(|e| { + warn!("To-Device event is invalid: {event:?} {e}"); + 
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + })?, + )?; + }, - DeviceIdOrAllDevices::AllDevices => { - for target_device_id in - services().users.all_device_ids(target_user_id) - { - services().users.add_to_device_event( - &sender, - target_user_id, - &target_device_id?, - &ev_type.to_string(), - event.deserialize_as().map_err(|_| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Event is invalid", - ) - })?, - )?; - } - } - } - } - } + DeviceIdOrAllDevices::AllDevices => { + for target_device_id in services().users.all_device_ids(target_user_id) { + services().users.add_to_device_event( + &sender, + target_user_id, + &target_device_id?, + &ev_type.to_string(), + event.deserialize_as().map_err(|_| { + Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid") + })?, + )?; + } + }, + } + } + } - // Save transaction id with empty data - services() - .transaction_ids - .add_txnid(&sender, None, &message_id, &[])?; - } - Edu::SigningKeyUpdate(SigningKeyUpdateContent { - user_id, - master_key, - self_signing_key, - }) => { - if user_id.server_name() != sender_servername { - continue; - } - if let Some(master_key) = master_key { - services().users.add_cross_signing_keys( - &user_id, - &master_key, - &self_signing_key, - &None, - true, - )?; - } - } - Edu::_Custom(_) => {} - } - } + // Save transaction id with empty data + services().transaction_ids.add_txnid(&sender, None, &message_id, &[])?; + }, + Edu::SigningKeyUpdate(SigningKeyUpdateContent { + user_id, + master_key, + self_signing_key, + }) => { + if user_id.server_name() != sender_servername { + continue; + } + if let Some(master_key) = master_key { + services().users.add_cross_signing_keys(&user_id, &master_key, &self_signing_key, &None, true)?; + } + }, + Edu::_Custom(_) => {}, + } + } - Ok(send_transaction_message::v1::Response { - pdus: resolved_map - .into_iter() - .map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))) - .collect(), - }) + Ok(send_transaction_message::v1::Response { 
+ pdus: resolved_map.into_iter().map(|(e, r)| (e, r.map_err(|e| e.sanitized_error()))).collect(), + }) } /// # `GET /_matrix/federation/v1/event/{eventId}` /// /// Retrieves a single event from the server. /// -/// - Only works if a user of this server is currently invited or joined the room -pub async fn get_event_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +/// - Only works if a user of this server is currently invited or joined the +/// room +pub async fn get_event_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - let event = services() - .rooms - .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, room_id)? 
- { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, room_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); + } - if !services().rooms.state_accessor.server_can_see_event( - sender_servername, - room_id, - &body.event_id, - )? { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not allowed to see event.", - )); - } + if !services().rooms.state_accessor.server_can_see_event(sender_servername, room_id, &body.event_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not allowed to see event.")); + } - Ok(get_event::v1::Response { - origin: services().globals.server_name().to_owned(), - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdu: PduEvent::convert_to_outgoing_federation_event(event), - }) + Ok(get_event::v1::Response { + origin: services().globals.server_name().to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch::now(), + pdu: PduEvent::convert_to_outgoing_federation_event(event), + }) } /// # `GET /_matrix/federation/v1/backfill/` /// /// Retrieves events from before the sender joined the room, if the room's /// history visibility allows. 
-pub async fn get_backfill_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +pub async fn get_backfill_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - debug!("Got backfill request from: {}", sender_servername); + debug!("Got backfill request from: {}", sender_servername); - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let until = body - .v - .iter() - .map(|eventid| services().rooms.timeline.get_pdu_count(eventid)) - .filter_map(|r| r.ok().flatten()) - .max() - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "No known eventid in v", - ))?; + let until = body + .v + .iter() + .map(|eventid| services().rooms.timeline.get_pdu_count(eventid)) + .filter_map(|r| r.ok().flatten()) + .max() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "No known eventid in v"))?; - let limit = body.limit.min(uint!(100)); + let limit = body.limit.min(uint!(100)); - let all_events = services() - .rooms - .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? 
- .take(limit.try_into().unwrap()); + let all_events = services() + .rooms + .timeline + .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until)? + .take(limit.try_into().unwrap()); - let events = all_events - .filter_map(std::result::Result::ok) - .filter(|(_, e)| { - matches!( - services().rooms.state_accessor.server_can_see_event( - sender_servername, - &e.room_id, - &e.event_id, - ), - Ok(true), - ) - }) - .map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) - .filter_map(|r| r.ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(); + let events = all_events + .filter_map(std::result::Result::ok) + .filter(|(_, e)| { + matches!( + services().rooms.state_accessor.server_can_see_event(sender_servername, &e.room_id, &e.event_id,), + Ok(true), + ) + }) + .map(|(_, pdu)| services().rooms.timeline.get_pdu_json(&pdu.event_id)) + .filter_map(|r| r.ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(); - Ok(get_backfill::v1::Response { - origin: services().globals.server_name().to_owned(), - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - pdus: events, - }) + Ok(get_backfill::v1::Response { + origin: services().globals.server_name().to_owned(), + origin_server_ts: MilliSecondsSinceUnixEpoch::now(), + pdus: events, + }) } /// # `POST /_matrix/federation/v1/get_missing_events/{roomId}` /// /// Retrieves events that the sender is missing. 
pub async fn get_missing_events_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let mut queued_events = body.latest_events.clone(); - let mut events = Vec::new(); + let mut queued_events = body.latest_events.clone(); + let mut events = Vec::new(); - let mut i = 0; - while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { - if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? { - let room_id_str = pdu - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + let mut i = 0; + while i < queued_events.len() && events.len() < u64::from(body.limit) as usize { + if let Some(pdu) = services().rooms.timeline.get_pdu_json(&queued_events[i])? 
{ + let room_id_str = pdu + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let event_room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let event_room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - if event_room_id != body.room_id { - warn!( - "Evil event detected: Event {} found while searching in room {}", - queued_events[i], body.room_id - ); - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Evil event detected", - )); - } + if event_room_id != body.room_id { + warn!( + "Evil event detected: Event {} found while searching in room {}", + queued_events[i], body.room_id + ); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Evil event detected")); + } - if body.earliest_events.contains(&queued_events[i]) { - i += 1; - continue; - } + if body.earliest_events.contains(&queued_events[i]) { + i += 1; + continue; + } - if !services().rooms.state_accessor.server_can_see_event( - sender_servername, - &body.room_id, - &queued_events[i], - )? { - i += 1; - continue; - } + if !services().rooms.state_accessor.server_can_see_event( + sender_servername, + &body.room_id, + &queued_events[i], + )? { + i += 1; + continue; + } - queued_events.extend_from_slice( - &serde_json::from_value::>( - serde_json::to_value(pdu.get("prev_events").cloned().ok_or_else(|| { - Error::bad_database("Event in db has no prev_events field.") - })?) 
- .expect("canonical json is valid json value"), - ) - .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, - ); - events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); - } - i += 1; - } + queued_events.extend_from_slice( + &serde_json::from_value::>( + serde_json::to_value( + pdu.get("prev_events") + .cloned() + .ok_or_else(|| Error::bad_database("Event in db has no prev_events field."))?, + ) + .expect("canonical json is valid json value"), + ) + .map_err(|_| Error::bad_database("Invalid prev_events content in pdu in db."))?, + ); + events.push(PduEvent::convert_to_outgoing_federation_event(pdu)); + } + i += 1; + } - Ok(get_missing_events::v1::Response { events }) + Ok(get_missing_events::v1::Response { + events, + }) } /// # `GET /_matrix/federation/v1/event_auth/{roomId}/{eventId}` @@ -1330,896 +1114,729 @@ pub async fn get_missing_events_route( /// /// - This does not include the event itself pub async fn get_event_authorization_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? 
{ + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let event = services() - .rooms - .timeline - .get_pdu_json(&body.event_id)? - .ok_or_else(|| { - warn!("Event not found, event ID: {:?}", &body.event_id); - Error::BadRequest(ErrorKind::NotFound, "Event not found.") - })?; + let event = services().rooms.timeline.get_pdu_json(&body.event_id)?.ok_or_else(|| { + warn!("Event not found, event ID: {:?}", &body.event_id); + Error::BadRequest(ErrorKind::NotFound, "Event not found.") + })?; - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - let room_id = <&RoomId>::try_from(room_id_str) - .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]) - .await?; + let auth_chain_ids = services().rooms.auth_chain.get_auth_chain(room_id, vec![Arc::from(&*body.event_id)]).await?; - Ok(get_event_authorization::v1::Response { - auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - }) + Ok(get_event_authorization::v1::Response { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok()?) 
+ .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + }) } /// # `GET /_matrix/federation/v1/state/{roomId}` /// /// Retrieves the current state of the room. -pub async fn get_room_state_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +pub async fn get_room_state_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let shortstatehash = services() - .rooms - .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; + let shortstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; - let pdus = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| { - PduEvent::convert_to_outgoing_federation_event( - services() - .rooms - .timeline - .get_pdu_json(&id) - .unwrap() - .unwrap(), - ) - }) - .collect(); + let pdus = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await? 
+ .into_values() + .map(|id| { + PduEvent::convert_to_outgoing_federation_event( + services().rooms.timeline.get_pdu_json(&id).unwrap().unwrap(), + ) + }) + .collect(); - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + let auth_chain_ids = + services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; - Ok(get_room_state::v1::Response { - auth_chain: auth_chain_ids - .filter_map( - |id| match services().rooms.timeline.get_pdu_json(&id).ok()? { - Some(json) => Some(PduEvent::convert_to_outgoing_federation_event(json)), - None => { - error!("Could not find event json for {id} in db."); - None - } - }, - ) - .collect(), - pdus, - }) + Ok(get_room_state::v1::Response { + auth_chain: auth_chain_ids + .filter_map(|id| match services().rooms.timeline.get_pdu_json(&id).ok()? { + Some(json) => Some(PduEvent::convert_to_outgoing_federation_event(json)), + None => { + error!("Could not find event json for {id} in db."); + None + }, + }) + .collect(), + pdus, + }) } /// # `GET /_matrix/federation/v1/state_ids/{roomId}` /// /// Retrieves the current state of the room. pub async fn get_room_state_ids_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - if !services() - .rooms - .state_cache - .server_in_room(sender_servername, &body.room_id)? - { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server is not in room.", - )); - } + if !services().rooms.state_cache.server_in_room(sender_servername, &body.room_id)? 
{ + return Err(Error::BadRequest(ErrorKind::Forbidden, "Server is not in room.")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let shortstatehash = services() - .rooms - .state_accessor - .pdu_shortstatehash(&body.event_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; + let shortstatehash = services() + .rooms + .state_accessor + .pdu_shortstatehash(&body.event_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; - let pdu_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await? - .into_values() - .map(|id| (*id).to_owned()) - .collect(); + let pdu_ids = services() + .rooms + .state_accessor + .state_full_ids(shortstatehash) + .await? + .into_values() + .map(|id| (*id).to_owned()) + .collect(); - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]) - .await?; + let auth_chain_ids = + services().rooms.auth_chain.get_auth_chain(&body.room_id, vec![Arc::from(&*body.event_id)]).await?; - Ok(get_room_state_ids::v1::Response { - auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), - pdu_ids, - }) + Ok(get_room_state_ids::v1::Response { + auth_chain_ids: auth_chain_ids.map(|id| (*id).to_owned()).collect(), + pdu_ids, + }) } /// # `GET /_matrix/federation/v1/make_join/{roomId}/{userId}` /// /// Creates a join template. pub async fn create_join_event_template_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if !services().rooms.metadata.exists(&body.room_id)? 
{ - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } + if !services().rooms.metadata.exists(&body.room_id)? { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(body.room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; - // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = services().rooms.state_accessor.room_state_get( - &body.room_id, - &StateEventType::RoomJoinRules, - "", - )?; + // TODO: Conduit does not implement restricted join rules yet, we always reject + let join_rules_event = + services().rooms.state_accessor.room_state_get(&body.room_id, &StateEventType::RoomJoinRules, "")?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + let join_rules_event_content: Option = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + 
.transpose()?; - if let Some(join_rules_event_content) = join_rules_event_content { - if matches!( - join_rules_event_content.join_rule, - JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. } - ) { - return Err(Error::BadRequest( - ErrorKind::UnableToAuthorizeJoin, - "Conduit does not support restricted rooms yet.", - )); - } - } + if let Some(join_rules_event_content) = join_rules_event_content { + if matches!( + join_rules_event_content.join_rule, + JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. } + ) { + return Err(Error::BadRequest( + ErrorKind::UnableToAuthorizeJoin, + "Conduit does not support restricted rooms yet.", + )); + } + } - let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; - if !body.ver.contains(&room_version_id) { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: room_version_id, - }, - "Room version not supported.", - )); - } + let room_version_id = services().rooms.state.get_room_version(&body.room_id)?; + if !body.ver.contains(&room_version_id) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: room_version_id, + }, + "Room version not supported.", + )); + } - let content = to_raw_value(&RoomMemberEventContent { - avatar_url: None, - blurhash: None, - displayname: None, - is_direct: None, - membership: MembershipState::Join, - third_party_invite: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("member event is valid value"); + let content = to_raw_value(&RoomMemberEventContent { + avatar_url: None, + blurhash: None, + displayname: None, + is_direct: None, + membership: MembershipState::Join, + third_party_invite: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("member event is valid value"); - let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content, - unsigned: 
None, - state_key: Some(body.user_id.to_string()), - redacts: None, - }, - &body.user_id, - &body.room_id, - &state_lock, - )?; + let (_pdu, mut pdu_json) = services().rooms.timeline.create_hash_and_sign_event( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content, + unsigned: None, + state_key: Some(body.user_id.to_string()), + redacts: None, + }, + &body.user_id, + &body.room_id, + &state_lock, + )?; - drop(state_lock); + drop(state_lock); - pdu_json.remove("event_id"); + pdu_json.remove("event_id"); - Ok(prepare_join_event::v1::Response { - room_version: Some(room_version_id), - event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), - }) + Ok(prepare_join_event::v1::Response { + room_version: Some(room_version_id), + event: to_raw_value(&pdu_json).expect("CanonicalJson can be serialized to JSON"), + }) } async fn create_join_event( - sender_servername: &ServerName, - room_id: &RoomId, - pdu: &RawJsonValue, + sender_servername: &ServerName, room_id: &RoomId, pdu: &RawJsonValue, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if !services().rooms.metadata.exists(room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server.", - )); - } + if !services().rooms.metadata.exists(room_id)? 
{ + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server.")); + } - services() - .rooms - .event_handler - .acl_check(sender_servername, room_id)?; + services().rooms.event_handler.acl_check(sender_servername, room_id)?; - // TODO: Conduit does not implement restricted join rules yet, we always reject - let join_rules_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomJoinRules, - "", - )?; + // TODO: Conduit does not implement restricted join rules yet, we always reject + let join_rules_event = + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?; - let join_rules_event_content: Option = join_rules_event - .as_ref() - .map(|join_rules_event| { - serde_json::from_str(join_rules_event.content.get()).map_err(|e| { - warn!("Invalid join rules event: {}", e); - Error::bad_database("Invalid join rules event in db.") - }) - }) - .transpose()?; + let join_rules_event_content: Option = join_rules_event + .as_ref() + .map(|join_rules_event| { + serde_json::from_str(join_rules_event.content.get()).map_err(|e| { + warn!("Invalid join rules event: {}", e); + Error::bad_database("Invalid join rules event in db.") + }) + }) + .transpose()?; - if let Some(join_rules_event_content) = join_rules_event_content { - if matches!( - join_rules_event_content.join_rule, - JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. } - ) { - return Err(Error::BadRequest( - ErrorKind::UnableToAuthorizeJoin, - "Conduit does not support restricted rooms yet.", - )); - } - } + if let Some(join_rules_event_content) = join_rules_event_content { + if matches!( + join_rules_event_content.join_rule, + JoinRule::Restricted { .. } | JoinRule::KnockRestricted { .. 
} + ) { + return Err(Error::BadRequest( + ErrorKind::UnableToAuthorizeJoin, + "Conduit does not support restricted rooms yet.", + )); + } + } - // We need to return the state prior to joining, let's keep a reference to that here - let shortstatehash = services() - .rooms - .state - .get_room_shortstatehash(room_id)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Pdu state not found.", - ))?; + // We need to return the state prior to joining, let's keep a reference to that + // here + let shortstatehash = services() + .rooms + .state + .get_room_shortstatehash(room_id)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Pdu state not found."))?; - let pub_key_map = RwLock::new(BTreeMap::new()); - // let mut auth_cache = EventMap::new(); + let pub_key_map = RwLock::new(BTreeMap::new()); + // let mut auth_cache = EventMap::new(); - // We do not add the event_id field to the pdu here because of signature and hashes checks - let room_version_id = services().rooms.state.get_room_version(room_id)?; - let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { - Ok(t) => t, - Err(_) => { - // Event could not be converted to canonical json - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not convert event to canonical json.", - )); - } - }; + // We do not add the event_id field to the pdu here because of signature and + // hashes checks + let room_version_id = services().rooms.state.get_room_version(room_id)?; + let (event_id, value) = match gen_event_id_canonical_json(pdu, &room_version_id) { + Ok(t) => t, + Err(_) => { + // Event could not be converted to canonical json + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not convert event to canonical json.", + )); + }, + }; - let origin: OwnedServerName = serde_json::from_value( - serde_json::to_value(value.get("origin").ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event needs an origin field.", - ))?) 
- .expect("CanonicalJson is valid json value"), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; + let origin: OwnedServerName = serde_json::from_value( + serde_json::to_value( + value.get("origin").ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event needs an origin field."))?, + ) + .expect("CanonicalJson is valid json value"), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - services() - .rooms - .event_handler - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; + services().rooms.event_handler.fetch_required_signing_keys([&value], &pub_key_map).await?; - let mutex = Arc::clone( - services() - .globals - .roomid_mutex_federation - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - let pdu_id: Vec = services() - .rooms - .event_handler - .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) - .await? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; - drop(mutex_lock); + let mutex = + Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.to_owned()).or_default()); + let mutex_lock = mutex.lock().await; + let pdu_id: Vec = services() + .rooms + .event_handler + .handle_incoming_pdu(&origin, &event_id, room_id, value, true, &pub_key_map) + .await? 
+ .ok_or(Error::BadRequest( + ErrorKind::InvalidParam, + "Could not accept incoming PDU as timeline event.", + ))?; + drop(mutex_lock); - let state_ids = services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await?; - let auth_chain_ids = services() - .rooms - .auth_chain - .get_auth_chain(room_id, state_ids.values().cloned().collect()) - .await?; + let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?; + let auth_chain_ids = + services().rooms.auth_chain.get_auth_chain(room_id, state_ids.values().cloned().collect()).await?; - let servers = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(std::result::Result::ok) - .filter(|server| &**server != services().globals.server_name()); + let servers = services() + .rooms + .state_cache + .room_servers(room_id) + .filter_map(std::result::Result::ok) + .filter(|server| &**server != services().globals.server_name()); - services().sending.send_pdu(servers, &pdu_id)?; + services().sending.send_pdu(servers, &pdu_id)?; - Ok(create_join_event::v1::RoomState { - auth_chain: auth_chain_ids - .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - state: state_ids - .iter() - .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) - .map(PduEvent::convert_to_outgoing_federation_event) - .collect(), - event: None, // TODO: handle restricted joins - }) + Ok(create_join_event::v1::RoomState { + auth_chain: auth_chain_ids + .filter_map(|id| services().rooms.timeline.get_pdu_json(&id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + state: state_ids + .iter() + .filter_map(|(_, id)| services().rooms.timeline.get_pdu_json(id).ok().flatten()) + .map(PduEvent::convert_to_outgoing_federation_event) + .collect(), + event: None, // TODO: handle restricted joins + }) } /// # `PUT 
/_matrix/federation/v1/send_join/{roomId}/{eventId}` /// /// Submits a signed join event. pub async fn create_join_event_v1_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; - Ok(create_join_event::v1::Response { room_state }) + Ok(create_join_event::v1::Response { + room_state, + }) } /// # `PUT /_matrix/federation/v2/send_join/{roomId}/{eventId}` /// /// Submits a signed join event. pub async fn create_join_event_v2_route( - body: Ruma, + body: Ruma, ) -> Result { - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - let create_join_event::v1::RoomState { - auth_chain, - state, - event, - } = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; - let room_state = create_join_event::v2::RoomState { - members_omitted: false, - auth_chain, - state, - event, - servers_in_room: None, - }; + let create_join_event::v1::RoomState { + auth_chain, + state, + event, + } = create_join_event(sender_servername, &body.room_id, &body.pdu).await?; + let room_state = create_join_event::v2::RoomState { + members_omitted: false, + auth_chain, + state, + event, + servers_in_room: None, + }; - Ok(create_join_event::v2::Response { room_state }) + Ok(create_join_event::v2::Response { + room_state, + }) } /// # `PUT /_matrix/federation/v2/invite/{roomId}/{eventId}` /// /// Invites a remote user to a room. 
-pub async fn create_invite_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +pub async fn create_invite_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - services() - .rooms - .event_handler - .acl_check(sender_servername, &body.room_id)?; + services().rooms.event_handler.acl_check(sender_servername, &body.room_id)?; - if !services() - .globals - .supported_room_versions() - .contains(&body.room_version) - { - return Err(Error::BadRequest( - ErrorKind::IncompatibleRoomVersion { - room_version: body.room_version.clone(), - }, - "Server does not support this room version.", - )); - } + if !services().globals.supported_room_versions().contains(&body.room_version) { + return Err(Error::BadRequest( + ErrorKind::IncompatibleRoomVersion { + room_version: body.room_version.clone(), + }, + "Server does not support this room version.", + )); + } - let mut signed_event = utils::to_canonical_object(&body.event).map_err(|e| { - error!("Failed to convert invite event to canonical JSON: {}", e); - Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid.") - })?; + let mut signed_event = utils::to_canonical_object(&body.event).map_err(|e| { + error!("Failed to convert invite event to canonical JSON: {}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid.") + })?; - ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut signed_event, - &body.room_version, - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; + ruma::signatures::hash_and_sign_event( + 
services().globals.server_name().as_str(), + services().globals.keypair(), + &mut signed_event, + &body.room_version, + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Failed to sign event."))?; - // Generate event id - let event_id = EventId::parse(format!( - "${}", - ruma::signatures::reference_hash(&signed_event, &body.room_version) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); + // Generate event id + let event_id = EventId::parse(format!( + "${}", + ruma::signatures::reference_hash(&signed_event, &body.room_version) + .expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); - // Add event_id back - signed_event.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.to_string()), - ); + // Add event_id back + signed_event.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.to_string())); - let sender: OwnedUserId = serde_json::from_value( - signed_event - .get("sender") - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event had no sender field.", - ))? - .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; + let sender: OwnedUserId = serde_json::from_value( + signed_event + .get("sender") + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event had no sender field."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "sender is not a user id."))?; - let invited_user: Box<_> = serde_json::from_value( - signed_event - .get("state_key") - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Event had no state_key field.", - ))? 
- .clone() - .into(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; + let invited_user: Box<_> = serde_json::from_value( + signed_event + .get("state_key") + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event had no state_key field."))? + .clone() + .into(), + ) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "state_key is not a user id."))?; - if services().rooms.metadata.is_banned(&body.room_id)? - && !services().users.is_admin(&invited_user)? - { - info!( - "Received remote invite from server {} for room {} and for user {invited_user}, but room is banned by us.", - &sender_servername, &body.room_id - ); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This room is banned on this homeserver.", - )); - } + if services().rooms.metadata.is_banned(&body.room_id)? && !services().users.is_admin(&invited_user)? { + info!( + "Received remote invite from server {} for room {} and for user {invited_user}, but room is banned by us.", + &sender_servername, &body.room_id + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This room is banned on this homeserver.", + )); + } - if services().globals.block_non_admin_invites() && !services().users.is_admin(&invited_user)? { - info!("Received remote invite from server {} for room {} and for user {invited_user} who is not an admin, but \"block_non_admin_invites\" is enabled, rejecting.", &sender_servername, &body.room_id); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "This server does not allow room invites.", - )); - } + if services().globals.block_non_admin_invites() && !services().users.is_admin(&invited_user)? 
{ + info!( + "Received remote invite from server {} for room {} and for user {invited_user} who is not an admin, but \ + \"block_non_admin_invites\" is enabled, rejecting.", + &sender_servername, &body.room_id + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "This server does not allow room invites.", + )); + } - let mut invite_state = body.invite_room_state.clone(); + let mut invite_state = body.invite_room_state.clone(); - let mut event: JsonObject = serde_json::from_str(body.event.get()) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; + let mut event: JsonObject = serde_json::from_str(body.event.get()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event bytes."))?; - event.insert("event_id".to_owned(), "$dummy".into()); + event.insert("event_id".to_owned(), "$dummy".into()); - let pdu: PduEvent = serde_json::from_value(event.into()).map_err(|e| { - warn!("Invalid invite event: {}", e); - Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.") - })?; + let pdu: PduEvent = serde_json::from_value(event.into()).map_err(|e| { + warn!("Invalid invite event: {}", e); + Error::BadRequest(ErrorKind::InvalidParam, "Invalid invite event.") + })?; - invite_state.push(pdu.to_stripped_state_event()); + invite_state.push(pdu.to_stripped_state_event()); - // If we are active in the room, the remote server will notify us about the join via /send - if !services() - .rooms - .state_cache - .server_in_room(services().globals.server_name(), &body.room_id)? - { - services() - .rooms - .state_cache - .update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - true, - ) - .await?; - } + // If we are active in the room, the remote server will notify us about the join + // via /send + if !services().rooms.state_cache.server_in_room(services().globals.server_name(), &body.room_id)? 
{ + services() + .rooms + .state_cache + .update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + &sender, + Some(invite_state), + true, + ) + .await?; + } - Ok(create_invite::v2::Response { - event: PduEvent::convert_to_outgoing_federation_event(signed_event), - }) + Ok(create_invite::v2::Response { + event: PduEvent::convert_to_outgoing_federation_event(signed_event), + }) } /// # `GET /_matrix/federation/v1/user/devices/{userId}` /// /// Gets information on all devices of the user. -pub async fn get_devices_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +pub async fn get_devices_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if body.user_id.server_name() != services().globals.server_name() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to access user from other server.", - )); - } + if body.user_id.server_name() != services().globals.server_name() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Tried to access user from other server.", + )); + } - let sender_servername = body - .sender_servername - .as_ref() - .expect("server is authenticated"); + let sender_servername = body.sender_servername.as_ref().expect("server is authenticated"); - Ok(get_devices::v1::Response { - user_id: body.user_id.clone(), - stream_id: services() - .users - .get_devicelist_version(&body.user_id)? 
- .unwrap_or(0) - .try_into() - .expect("version will not grow that large"), - devices: services() - .users - .all_devices_metadata(&body.user_id) - .filter_map(std::result::Result::ok) - .filter_map(|metadata| { - let device_id_string = metadata.device_id.as_str().to_owned(); - let device_display_name = match services().globals.allow_device_name_federation() { - true => metadata.display_name, - false => Some(device_id_string), - }; - Some(UserDevice { - keys: services() - .users - .get_device_keys(&body.user_id, &metadata.device_id) - .ok()??, - device_id: metadata.device_id, - device_display_name, - }) - }) - .collect(), - master_key: services().users.get_master_key(None, &body.user_id, &|u| { - u.server_name() == sender_servername - })?, - self_signing_key: services() - .users - .get_self_signing_key(None, &body.user_id, &|u| { - u.server_name() == sender_servername - })?, - }) + Ok(get_devices::v1::Response { + user_id: body.user_id.clone(), + stream_id: services() + .users + .get_devicelist_version(&body.user_id)? + .unwrap_or(0) + .try_into() + .expect("version will not grow that large"), + devices: services() + .users + .all_devices_metadata(&body.user_id) + .filter_map(std::result::Result::ok) + .filter_map(|metadata| { + let device_id_string = metadata.device_id.as_str().to_owned(); + let device_display_name = match services().globals.allow_device_name_federation() { + true => metadata.display_name, + false => Some(device_id_string), + }; + Some(UserDevice { + keys: services().users.get_device_keys(&body.user_id, &metadata.device_id).ok()??, + device_id: metadata.device_id, + device_display_name, + }) + }) + .collect(), + master_key: services().users.get_master_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?, + self_signing_key: services() + .users + .get_self_signing_key(None, &body.user_id, &|u| u.server_name() == sender_servername)?, + }) } /// # `GET /_matrix/federation/v1/query/directory` /// /// Resolve a room alias to a room id. 
pub async fn get_room_information_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - let room_id = services() - .rooms - .alias - .resolve_local_alias(&body.room_alias)? - .ok_or(Error::BadRequest( - ErrorKind::NotFound, - "Room alias not found.", - ))?; + let room_id = services() + .rooms + .alias + .resolve_local_alias(&body.room_alias)? + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Room alias not found."))?; - Ok(get_room_information::v1::Response { - room_id, - servers: vec![services().globals.server_name().to_owned()], - }) + Ok(get_room_information::v1::Response { + room_id, + servers: vec![services().globals.server_name().to_owned()], + }) } /// # `GET /_matrix/federation/v1/query/profile` /// /// Gets information on a profile. pub async fn get_profile_information_route( - body: Ruma, + body: Ruma, ) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if body.user_id.server_name() != services().globals.server_name() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this server", - )); - } + if body.user_id.server_name() != services().globals.server_name() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User does not belong to this server", + )); + } - let mut displayname = None; - let mut avatar_url = None; - let mut blurhash = None; + let mut displayname = None; + let mut avatar_url = None; + let mut blurhash = None; - match &body.field { - Some(ProfileField::DisplayName) => { - displayname = services().users.displayname(&body.user_id)?; - } - Some(ProfileField::AvatarUrl) => { - avatar_url = 
services().users.avatar_url(&body.user_id)?; - blurhash = services().users.blurhash(&body.user_id)?; - } - // TODO: what to do with custom - Some(_) => {} - None => { - displayname = services().users.displayname(&body.user_id)?; - avatar_url = services().users.avatar_url(&body.user_id)?; - blurhash = services().users.blurhash(&body.user_id)?; - } - } + match &body.field { + Some(ProfileField::DisplayName) => { + displayname = services().users.displayname(&body.user_id)?; + }, + Some(ProfileField::AvatarUrl) => { + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)?; + }, + // TODO: what to do with custom + Some(_) => {}, + None => { + displayname = services().users.displayname(&body.user_id)?; + avatar_url = services().users.avatar_url(&body.user_id)?; + blurhash = services().users.blurhash(&body.user_id)?; + }, + } - Ok(get_profile_information::v1::Response { - blurhash, - displayname, - avatar_url, - }) + Ok(get_profile_information::v1::Response { + blurhash, + displayname, + avatar_url, + }) } /// # `POST /_matrix/federation/v1/user/keys/query` /// /// Gets devices and identity keys for the given users. 
pub async fn get_keys_route(body: Ruma) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if body - .device_keys - .iter() - .any(|(u, _)| u.server_name() != services().globals.server_name()) - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not belong to this server.", - )); - } + if body.device_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User does not belong to this server.", + )); + } - let result = get_keys_helper( - None, - &body.device_keys, - |u| Some(u.server_name()) == body.sender_servername.as_deref(), - services().globals.allow_device_name_federation(), - ) - .await?; + let result = get_keys_helper( + None, + &body.device_keys, + |u| Some(u.server_name()) == body.sender_servername.as_deref(), + services().globals.allow_device_name_federation(), + ) + .await?; - Ok(get_keys::v1::Response { - device_keys: result.device_keys, - master_keys: result.master_keys, - self_signing_keys: result.self_signing_keys, - }) + Ok(get_keys::v1::Response { + device_keys: result.device_keys, + master_keys: result.master_keys, + self_signing_keys: result.self_signing_keys, + }) } /// # `POST /_matrix/federation/v1/user/keys/claim` /// /// Claims one-time keys. 
-pub async fn claim_keys_route( - body: Ruma, -) -> Result { - if !services().globals.allow_federation() { - return Err(Error::bad_config("Federation is disabled.")); - } +pub async fn claim_keys_route(body: Ruma) -> Result { + if !services().globals.allow_federation() { + return Err(Error::bad_config("Federation is disabled.")); + } - if body - .one_time_keys - .iter() - .any(|(u, _)| u.server_name() != services().globals.server_name()) - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to access user from other server.", - )); - } + if body.one_time_keys.iter().any(|(u, _)| u.server_name() != services().globals.server_name()) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Tried to access user from other server.", + )); + } - let result = claim_keys_helper(&body.one_time_keys).await?; + let result = claim_keys_helper(&body.one_time_keys).await?; - Ok(claim_keys::v1::Response { - one_time_keys: result.one_time_keys, - }) + Ok(claim_keys::v1::Response { + one_time_keys: result.one_time_keys, + }) } /// # `GET /.well-known/matrix/server` pub async fn well_known_server_route() -> Result { - let server_url = match services().globals.well_known_server() { - Some(url) => url.clone(), - None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), - }; + let server_url = match services().globals.well_known_server() { + Some(url) => url.clone(), + None => return Err(Error::BadRequest(ErrorKind::NotFound, "Not found.")), + }; - Ok(Json(serde_json::json!({ - "m.server": server_url - }))) + Ok(Json(serde_json::json!({ + "m.server": server_url + }))) } #[cfg(test)] mod tests { - use super::{add_port_to_hostname, get_ip_with_port, FedDest}; + use super::{add_port_to_hostname, get_ip_with_port, FedDest}; - #[test] - fn ips_get_default_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1"), - Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("dead:beef::"), - 
Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) - ); - } + #[test] + fn ips_get_default_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1"), + Some(FedDest::Literal("1.1.1.1:8448".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("dead:beef::"), + Some(FedDest::Literal("[dead:beef::]:8448".parse().unwrap())) + ); + } - #[test] - fn ips_keep_custom_ports() { - assert_eq!( - get_ip_with_port("1.1.1.1:1234"), - Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) - ); - assert_eq!( - get_ip_with_port("[dead::beef]:8933"), - Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) - ); - } + #[test] + fn ips_keep_custom_ports() { + assert_eq!( + get_ip_with_port("1.1.1.1:1234"), + Some(FedDest::Literal("1.1.1.1:1234".parse().unwrap())) + ); + assert_eq!( + get_ip_with_port("[dead::beef]:8933"), + Some(FedDest::Literal("[dead::beef]:8933".parse().unwrap())) + ); + } - #[test] - fn hostnames_get_default_ports() { - assert_eq!( - add_port_to_hostname("example.com"), - FedDest::Named(String::from("example.com"), String::from(":8448")) - ) - } + #[test] + fn hostnames_get_default_ports() { + assert_eq!( + add_port_to_hostname("example.com"), + FedDest::Named(String::from("example.com"), String::from(":8448")) + ) + } - #[test] - fn hostnames_keep_custom_ports() { - assert_eq!( - add_port_to_hostname("example.com:1337"), - FedDest::Named(String::from("example.com"), String::from(":1337")) - ) - } + #[test] + fn hostnames_keep_custom_ports() { + assert_eq!( + add_port_to_hostname("example.com:1337"), + FedDest::Named(String::from("example.com"), String::from(":1337")) + ) + } } diff --git a/src/config/mod.rs b/src/config/mod.rs index c3ed5270..c5387dcf 100644 --- a/src/config/mod.rs +++ b/src/config/mod.rs @@ -1,9 +1,9 @@ use std::{ - collections::BTreeMap, - fmt, - fmt::Write as _, - net::{IpAddr, Ipv4Addr}, - path::PathBuf, + collections::BTreeMap, + fmt, + fmt::Write as _, + net::{IpAddr, Ipv4Addr}, + path::PathBuf, }; use either::Either; @@ 
-21,539 +21,464 @@ mod proxy; #[derive(Deserialize, Clone, Debug)] #[serde(transparent)] pub struct ListeningPort { - #[serde(with = "either::serde_untagged")] - pub ports: Either>, + #[serde(with = "either::serde_untagged")] + pub ports: Either>, } /// all the config options for conduwuit #[derive(Clone, Debug, Deserialize)] pub struct Config { - /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) - #[serde(default = "default_address")] - pub address: IpAddr, - /// default TCP port(s) conduwuit will listen on - #[serde(default = "default_port")] - pub port: ListeningPort, - pub tls: Option, - pub unix_socket_path: Option, - #[serde(default = "default_unix_socket_perms")] - pub unix_socket_perms: u32, - pub server_name: OwnedServerName, - #[serde(default = "default_database_backend")] - pub database_backend: String, - pub database_path: String, - #[serde(default = "default_db_cache_capacity_mb")] - pub db_cache_capacity_mb: f64, - #[serde(default = "default_new_user_displayname_suffix")] - pub new_user_displayname_suffix: String, - #[serde(default = "true_fn")] - pub allow_check_for_updates: bool, - #[serde(default = "default_conduit_cache_capacity_modifier")] - pub conduit_cache_capacity_modifier: f64, - #[serde(default = "default_pdu_cache_capacity")] - pub pdu_cache_capacity: u32, - #[serde(default = "default_cleanup_second_interval")] - pub cleanup_second_interval: u32, - #[serde(default = "default_max_request_size")] - pub max_request_size: u32, - #[serde(default = "default_max_concurrent_requests")] - pub max_concurrent_requests: u16, - #[serde(default = "default_max_fetch_prev_events")] - pub max_fetch_prev_events: u16, - #[serde(default)] - pub allow_registration: bool, - #[serde(default)] - pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, - pub registration_token: Option, - #[serde(default = "true_fn")] - pub allow_encryption: bool, - #[serde(default = "true_fn")] - pub allow_federation: bool, - 
#[serde(default)] - pub allow_public_room_directory_over_federation: bool, - #[serde(default)] - pub allow_public_room_directory_without_auth: bool, - #[serde(default)] - pub allow_device_name_federation: bool, - #[serde(default = "true_fn")] - pub allow_room_creation: bool, - #[serde(default = "true_fn")] - pub allow_unstable_room_versions: bool, - #[serde(default = "default_default_room_version")] - pub default_room_version: RoomVersionId, - pub well_known_client: Option, - pub well_known_server: Option, - #[serde(default)] - pub allow_jaeger: bool, - #[serde(default)] - pub tracing_flame: bool, - #[serde(default)] - pub proxy: ProxyConfig, - pub jwt_secret: Option, - #[serde(default = "default_trusted_servers")] - pub trusted_servers: Vec, - #[serde(default = "true_fn")] - pub query_trusted_key_servers_first: bool, - #[serde(default = "default_log")] - pub log: String, - #[serde(default)] - pub turn_username: String, - #[serde(default)] - pub turn_password: String, - #[serde(default = "Vec::new")] - pub turn_uris: Vec, - #[serde(default)] - pub turn_secret: String, - #[serde(default = "default_turn_ttl")] - pub turn_ttl: u64, + /// [`IpAddr`] conduwuit will listen on (can be IPv4 or IPv6) + #[serde(default = "default_address")] + pub address: IpAddr, + /// default TCP port(s) conduwuit will listen on + #[serde(default = "default_port")] + pub port: ListeningPort, + pub tls: Option, + pub unix_socket_path: Option, + #[serde(default = "default_unix_socket_perms")] + pub unix_socket_perms: u32, + pub server_name: OwnedServerName, + #[serde(default = "default_database_backend")] + pub database_backend: String, + pub database_path: String, + #[serde(default = "default_db_cache_capacity_mb")] + pub db_cache_capacity_mb: f64, + #[serde(default = "default_new_user_displayname_suffix")] + pub new_user_displayname_suffix: String, + #[serde(default = "true_fn")] + pub allow_check_for_updates: bool, + #[serde(default = "default_conduit_cache_capacity_modifier")] + pub 
conduit_cache_capacity_modifier: f64, + #[serde(default = "default_pdu_cache_capacity")] + pub pdu_cache_capacity: u32, + #[serde(default = "default_cleanup_second_interval")] + pub cleanup_second_interval: u32, + #[serde(default = "default_max_request_size")] + pub max_request_size: u32, + #[serde(default = "default_max_concurrent_requests")] + pub max_concurrent_requests: u16, + #[serde(default = "default_max_fetch_prev_events")] + pub max_fetch_prev_events: u16, + #[serde(default)] + pub allow_registration: bool, + #[serde(default)] + pub yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse: bool, + pub registration_token: Option, + #[serde(default = "true_fn")] + pub allow_encryption: bool, + #[serde(default = "true_fn")] + pub allow_federation: bool, + #[serde(default)] + pub allow_public_room_directory_over_federation: bool, + #[serde(default)] + pub allow_public_room_directory_without_auth: bool, + #[serde(default)] + pub allow_device_name_federation: bool, + #[serde(default = "true_fn")] + pub allow_room_creation: bool, + #[serde(default = "true_fn")] + pub allow_unstable_room_versions: bool, + #[serde(default = "default_default_room_version")] + pub default_room_version: RoomVersionId, + pub well_known_client: Option, + pub well_known_server: Option, + #[serde(default)] + pub allow_jaeger: bool, + #[serde(default)] + pub tracing_flame: bool, + #[serde(default)] + pub proxy: ProxyConfig, + pub jwt_secret: Option, + #[serde(default = "default_trusted_servers")] + pub trusted_servers: Vec, + #[serde(default = "true_fn")] + pub query_trusted_key_servers_first: bool, + #[serde(default = "default_log")] + pub log: String, + #[serde(default)] + pub turn_username: String, + #[serde(default)] + pub turn_password: String, + #[serde(default = "Vec::new")] + pub turn_uris: Vec, + #[serde(default)] + pub turn_secret: String, + #[serde(default = "default_turn_ttl")] + pub turn_ttl: u64, - #[serde(default = "default_rocksdb_log_level")] - pub 
rocksdb_log_level: String, - #[serde(default = "default_rocksdb_max_log_file_size")] - pub rocksdb_max_log_file_size: usize, - #[serde(default = "default_rocksdb_log_time_to_roll")] - pub rocksdb_log_time_to_roll: usize, - #[serde(default)] - pub rocksdb_optimize_for_spinning_disks: bool, - #[serde(default = "default_rocksdb_parallelism_threads")] - pub rocksdb_parallelism_threads: usize, + #[serde(default = "default_rocksdb_log_level")] + pub rocksdb_log_level: String, + #[serde(default = "default_rocksdb_max_log_file_size")] + pub rocksdb_max_log_file_size: usize, + #[serde(default = "default_rocksdb_log_time_to_roll")] + pub rocksdb_log_time_to_roll: usize, + #[serde(default)] + pub rocksdb_optimize_for_spinning_disks: bool, + #[serde(default = "default_rocksdb_parallelism_threads")] + pub rocksdb_parallelism_threads: usize, - pub emergency_password: Option, + pub emergency_password: Option, - #[serde(default = "default_notification_push_path")] - pub notification_push_path: String, + #[serde(default = "default_notification_push_path")] + pub notification_push_path: String, - #[serde(default)] - pub allow_local_presence: bool, - #[serde(default)] - pub allow_incoming_presence: bool, - #[serde(default)] - pub allow_outgoing_presence: bool, - #[serde(default = "default_presence_idle_timeout_s")] - pub presence_idle_timeout_s: u64, - #[serde(default = "default_presence_offline_timeout_s")] - pub presence_offline_timeout_s: u64, + #[serde(default)] + pub allow_local_presence: bool, + #[serde(default)] + pub allow_incoming_presence: bool, + #[serde(default)] + pub allow_outgoing_presence: bool, + #[serde(default = "default_presence_idle_timeout_s")] + pub presence_idle_timeout_s: u64, + #[serde(default = "default_presence_offline_timeout_s")] + pub presence_offline_timeout_s: u64, - #[serde(default)] - pub zstd_compression: bool, + #[serde(default)] + pub zstd_compression: bool, - #[serde(default)] - pub allow_guest_registration: bool, + #[serde(default)] + pub 
allow_guest_registration: bool, - #[serde(default = "Vec::new")] - pub prevent_media_downloads_from: Vec, + #[serde(default = "Vec::new")] + pub prevent_media_downloads_from: Vec, - #[serde(default = "default_ip_range_denylist")] - pub ip_range_denylist: Vec, + #[serde(default = "default_ip_range_denylist")] + pub ip_range_denylist: Vec, - #[serde(default = "Vec::new")] - pub url_preview_domain_contains_allowlist: Vec, - #[serde(default = "Vec::new")] - pub url_preview_domain_explicit_allowlist: Vec, - #[serde(default = "Vec::new")] - pub url_preview_url_contains_allowlist: Vec, - #[serde(default = "default_url_preview_max_spider_size")] - pub url_preview_max_spider_size: usize, - #[serde(default)] - pub url_preview_check_root_domain: bool, + #[serde(default = "Vec::new")] + pub url_preview_domain_contains_allowlist: Vec, + #[serde(default = "Vec::new")] + pub url_preview_domain_explicit_allowlist: Vec, + #[serde(default = "Vec::new")] + pub url_preview_url_contains_allowlist: Vec, + #[serde(default = "default_url_preview_max_spider_size")] + pub url_preview_max_spider_size: usize, + #[serde(default)] + pub url_preview_check_root_domain: bool, - #[serde(default = "RegexSet::empty")] - #[serde(with = "serde_regex")] - pub forbidden_room_names: RegexSet, + #[serde(default = "RegexSet::empty")] + #[serde(with = "serde_regex")] + pub forbidden_room_names: RegexSet, - #[serde(default = "RegexSet::empty")] - #[serde(with = "serde_regex")] - pub forbidden_usernames: RegexSet, + #[serde(default = "RegexSet::empty")] + #[serde(with = "serde_regex")] + pub forbidden_usernames: RegexSet, - #[serde(default)] - pub block_non_admin_invites: bool, + #[serde(default)] + pub block_non_admin_invites: bool, - #[serde(flatten)] - pub catchall: BTreeMap, + #[serde(flatten)] + pub catchall: BTreeMap, } #[derive(Clone, Debug, Deserialize)] pub struct TlsConfig { - pub certs: String, - pub key: String, - #[serde(default)] - /// Whether to listen and allow for HTTP and HTTPS connections 
(insecure!) - /// Only works / does something if the `axum_dual_protocol` feature flag was built - pub dual_protocol: bool, + pub certs: String, + pub key: String, + #[serde(default)] + /// Whether to listen and allow for HTTP and HTTPS connections (insecure!) + /// Only works / does something if the `axum_dual_protocol` feature flag was + /// built + pub dual_protocol: bool, } const DEPRECATED_KEYS: &[&str] = &["cache_capacity"]; impl Config { - /// Iterates over all the keys in the config file and warns if there is a deprecated key specified - pub fn warn_deprecated(&self) { - debug!("Checking for deprecated config keys"); - let mut was_deprecated = false; - for key in self - .catchall - .keys() - .filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) - { - warn!("Config parameter \"{}\" is deprecated, ignoring.", key); - was_deprecated = true; - } + /// Iterates over all the keys in the config file and warns if there is a + /// deprecated key specified + pub fn warn_deprecated(&self) { + debug!("Checking for deprecated config keys"); + let mut was_deprecated = false; + for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) { + warn!("Config parameter \"{}\" is deprecated, ignoring.", key); + was_deprecated = true; + } - if was_deprecated { - warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted"); - } - } + if was_deprecated { + warn!( + "Read conduit documentation and check your configuration if any new configuration parameters should \ + be adjusted" + ); + } + } - /// iterates over all the catchall keys (unknown config options) and warns if there are any. 
- pub fn warn_unknown_key(&self) { - debug!("Checking for unknown config keys"); - for key in self.catchall.keys().filter( - |key| "config".to_owned().ne(key.to_owned()), /* "config" is expected */ - ) { - warn!( - "Config parameter \"{}\" is unknown to conduwuit, ignoring.", - key - ); - } - } + /// iterates over all the catchall keys (unknown config options) and warns + /// if there are any. + pub fn warn_unknown_key(&self) { + debug!("Checking for unknown config keys"); + for key in + self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */) + { + warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key); + } + } - /// Checks the presence of the `address` and `unix_socket_path` keys in the raw_config, exiting the process if both keys were detected. - pub fn is_dual_listening(&self, raw_config: Figment) -> bool { - let check_address = raw_config.find_value("address"); - let check_unix_socket = raw_config.find_value("unix_socket_path"); + /// Checks the presence of the `address` and `unix_socket_path` keys in the + /// raw_config, exiting the process if both keys were detected. + pub fn is_dual_listening(&self, raw_config: Figment) -> bool { + let check_address = raw_config.find_value("address"); + let check_unix_socket = raw_config.find_value("unix_socket_path"); - // are the check_address and check_unix_socket keys both Ok (specified) at the same time? - if check_address.is_ok() && check_unix_socket.is_ok() { - error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option."); - return true; - } + // are the check_address and check_unix_socket keys both Ok (specified) at the + // same time? + if check_address.is_ok() && check_unix_socket.is_ok() { + error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. 
Please specify only one option."); + return true; + } - false - } + false + } } impl fmt::Display for Config { - fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { - // Prepare a list of config values to show - let lines = [ - ("Server name", self.server_name.host()), - ("Database backend", &self.database_backend), - ("Database path", &self.database_path), - ( - "Database cache capacity (MB)", - &self.db_cache_capacity_mb.to_string(), - ), - ( - "Cache capacity modifier", - &self.conduit_cache_capacity_modifier.to_string(), - ), - ("PDU cache capacity", &self.pdu_cache_capacity.to_string()), - ( - "Cleanup interval in seconds", - &self.cleanup_second_interval.to_string(), - ), - ("Maximum request size (bytes)", &self.max_request_size.to_string()), - ( - "Maximum concurrent requests", - &self.max_concurrent_requests.to_string(), - ), - ( - "Allow registration", - &self.allow_registration.to_string(), - ), - ( - "Registration token", - match self.registration_token { - Some(_) => "set", - None => "not set (open registration!)", - }, - ), - ( - "Allow guest registration (inherently false if allow registration is false)", - &self.allow_guest_registration.to_string(), - ), - ( - "New user display name suffix", - &self.new_user_displayname_suffix, - ), - ("Allow encryption", &self.allow_encryption.to_string()), - ("Allow federation", &self.allow_federation.to_string()), - ( - "Allow incoming federated presence requests (updates)", - &self.allow_incoming_presence.to_string(), - ), - ( - "Allow outgoing federated presence requests (updates)", - &self.allow_outgoing_presence.to_string(), - ), - ( - "Allow local presence requests (updates)", - &self.allow_local_presence.to_string(), - ), - ( - "Block non-admin room invites (local and remote, admins can still send and receive invites)", - &self.block_non_admin_invites.to_string(), - ), - ( - "Allow device name federation", - &self.allow_device_name_federation.to_string(), - ), - ("Notification push path", 
&self.notification_push_path), - ("Allow room creation", &self.allow_room_creation.to_string()), - ( - "Allow public room directory over federation", - &self.allow_public_room_directory_over_federation.to_string(), - ), - ( - "Allow public room directory without authentication", - &self.allow_public_room_directory_without_auth.to_string(), - ), - ( - "JWT secret", - match self.jwt_secret { - Some(_) => "set", - None => "not set", - }, - ), - ("Trusted servers", { - let mut lst = vec![]; - for server in &self.trusted_servers { - lst.push(server.host()); - } - &lst.join(", ") - }), - ( - "Query Trusted Key Servers First", - &self.query_trusted_key_servers_first.to_string(), - ), - ( - "TURN username", - if self.turn_username.is_empty() { - "not set" - } else { - &self.turn_username - }, - ), - ("TURN password", { - if self.turn_password.is_empty() { - "not set" - } else { - "set" - } - }), - ("TURN secret", { - if self.turn_secret.is_empty() { - "not set" - } else { - "set" - } - }), - ("Turn TTL", &self.turn_ttl.to_string()), - ("Turn URIs", { - let mut lst = vec![]; - for item in self.turn_uris.iter().cloned().enumerate() { - let (_, uri): (usize, String) = item; - lst.push(uri); - } - &lst.join(", ") - }), - ( - "zstd Response Body Compression", - &self.zstd_compression.to_string(), - ), - ("RocksDB database log level", &self.rocksdb_log_level), - ( - "RocksDB database log time-to-roll", - &self.rocksdb_log_time_to_roll.to_string(), - ), - ( - "RocksDB database max log file size", - &self.rocksdb_max_log_file_size.to_string(), - ), - ( - "RocksDB database optimize for spinning disks", - &self.rocksdb_optimize_for_spinning_disks.to_string(), - ), - ( - "RocksDB Parallelism Threads", - &self.rocksdb_parallelism_threads.to_string(), - ), - ("Prevent Media Downloads From", { - let mut lst = vec![]; - for domain in &self.prevent_media_downloads_from { - lst.push(domain.host()); - } - &lst.join(", ") - }), - ("Outbound Request IP Range Denylist", { - let mut lst = 
vec![]; - for item in self.ip_range_denylist.iter().cloned().enumerate() { - let (_, ip): (usize, String) = item; - lst.push(ip); - } - &lst.join(", ") - }), - ("Forbidden usernames", { - &self.forbidden_usernames.patterns().iter().join(", ") - }), - ("Forbidden room names", { - &self.forbidden_room_names.patterns().iter().join(", ") - }), - ( - "URL preview domain contains allowlist", - &self.url_preview_domain_contains_allowlist.join(", "), - ), - ( - "URL preview domain explicit allowlist", - &self.url_preview_domain_explicit_allowlist.join(", "), - ), - ( - "URL preview URL contains allowlist", - &self.url_preview_url_contains_allowlist.join(", "), - ), - ( - "URL preview maximum spider size", - &self.url_preview_max_spider_size.to_string(), - ), - ( - "URL preview check root domain", - &self.url_preview_check_root_domain.to_string(), - ), - ]; + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + // Prepare a list of config values to show + let lines = [ + ("Server name", self.server_name.host()), + ("Database backend", &self.database_backend), + ("Database path", &self.database_path), + ("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()), + ("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()), + ("PDU cache capacity", &self.pdu_cache_capacity.to_string()), + ("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()), + ("Maximum request size (bytes)", &self.max_request_size.to_string()), + ("Maximum concurrent requests", &self.max_concurrent_requests.to_string()), + ("Allow registration", &self.allow_registration.to_string()), + ( + "Registration token", + match self.registration_token { + Some(_) => "set", + None => "not set (open registration!)", + }, + ), + ( + "Allow guest registration (inherently false if allow registration is false)", + &self.allow_guest_registration.to_string(), + ), + ("New user display name suffix", &self.new_user_displayname_suffix), + ("Allow encryption", 
&self.allow_encryption.to_string()), + ("Allow federation", &self.allow_federation.to_string()), + ( + "Allow incoming federated presence requests (updates)", + &self.allow_incoming_presence.to_string(), + ), + ( + "Allow outgoing federated presence requests (updates)", + &self.allow_outgoing_presence.to_string(), + ), + ( + "Allow local presence requests (updates)", + &self.allow_local_presence.to_string(), + ), + ( + "Block non-admin room invites (local and remote, admins can still send and receive invites)", + &self.block_non_admin_invites.to_string(), + ), + ("Allow device name federation", &self.allow_device_name_federation.to_string()), + ("Notification push path", &self.notification_push_path), + ("Allow room creation", &self.allow_room_creation.to_string()), + ( + "Allow public room directory over federation", + &self.allow_public_room_directory_over_federation.to_string(), + ), + ( + "Allow public room directory without authentication", + &self.allow_public_room_directory_without_auth.to_string(), + ), + ( + "JWT secret", + match self.jwt_secret { + Some(_) => "set", + None => "not set", + }, + ), + ("Trusted servers", { + let mut lst = vec![]; + for server in &self.trusted_servers { + lst.push(server.host()); + } + &lst.join(", ") + }), + ( + "Query Trusted Key Servers First", + &self.query_trusted_key_servers_first.to_string(), + ), + ( + "TURN username", + if self.turn_username.is_empty() { + "not set" + } else { + &self.turn_username + }, + ), + ("TURN password", { + if self.turn_password.is_empty() { + "not set" + } else { + "set" + } + }), + ("TURN secret", { + if self.turn_secret.is_empty() { + "not set" + } else { + "set" + } + }), + ("Turn TTL", &self.turn_ttl.to_string()), + ("Turn URIs", { + let mut lst = vec![]; + for item in self.turn_uris.iter().cloned().enumerate() { + let (_, uri): (usize, String) = item; + lst.push(uri); + } + &lst.join(", ") + }), + ("zstd Response Body Compression", &self.zstd_compression.to_string()), + ("RocksDB 
database log level", &self.rocksdb_log_level), + ("RocksDB database log time-to-roll", &self.rocksdb_log_time_to_roll.to_string()), + ( + "RocksDB database max log file size", + &self.rocksdb_max_log_file_size.to_string(), + ), + ( + "RocksDB database optimize for spinning disks", + &self.rocksdb_optimize_for_spinning_disks.to_string(), + ), + ("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()), + ("Prevent Media Downloads From", { + let mut lst = vec![]; + for domain in &self.prevent_media_downloads_from { + lst.push(domain.host()); + } + &lst.join(", ") + }), + ("Outbound Request IP Range Denylist", { + let mut lst = vec![]; + for item in self.ip_range_denylist.iter().cloned().enumerate() { + let (_, ip): (usize, String) = item; + lst.push(ip); + } + &lst.join(", ") + }), + ("Forbidden usernames", { + &self.forbidden_usernames.patterns().iter().join(", ") + }), + ("Forbidden room names", { + &self.forbidden_room_names.patterns().iter().join(", ") + }), + ( + "URL preview domain contains allowlist", + &self.url_preview_domain_contains_allowlist.join(", "), + ), + ( + "URL preview domain explicit allowlist", + &self.url_preview_domain_explicit_allowlist.join(", "), + ), + ( + "URL preview URL contains allowlist", + &self.url_preview_url_contains_allowlist.join(", "), + ), + ("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()), + ("URL preview check root domain", &self.url_preview_check_root_domain.to_string()), + ]; - let mut msg: String = "Active config values:\n\n".to_owned(); + let mut msg: String = "Active config values:\n\n".to_owned(); - for line in lines.into_iter().enumerate() { - let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1); - } + for line in lines.into_iter().enumerate() { + let _ = writeln!(msg, "{}: {}", line.1 .0, line.1 .1); + } - write!(f, "{msg}") - } + write!(f, "{msg}") + } } -fn true_fn() -> bool { - true -} +fn true_fn() -> bool { true } -fn default_address() -> IpAddr { - 
Ipv4Addr::LOCALHOST.into() -} +fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() } fn default_port() -> ListeningPort { - ListeningPort { - ports: Either::Left(8008), - } + ListeningPort { + ports: Either::Left(8008), + } } -fn default_unix_socket_perms() -> u32 { - 660 -} +fn default_unix_socket_perms() -> u32 { 660 } -fn default_database_backend() -> String { - "rocksdb".to_owned() -} +fn default_database_backend() -> String { "rocksdb".to_owned() } -fn default_db_cache_capacity_mb() -> f64 { - 300.0 -} +fn default_db_cache_capacity_mb() -> f64 { 300.0 } -fn default_conduit_cache_capacity_modifier() -> f64 { - 1.0 -} +fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 } -fn default_pdu_cache_capacity() -> u32 { - 150_000 -} +fn default_pdu_cache_capacity() -> u32 { 150_000 } fn default_cleanup_second_interval() -> u32 { - 60 // every minute + 60 // every minute } fn default_max_request_size() -> u32 { - 20 * 1024 * 1024 // Default to 20 MB + 20 * 1024 * 1024 // Default to 20 MB } -fn default_max_concurrent_requests() -> u16 { - 500 -} +fn default_max_concurrent_requests() -> u16 { 500 } -fn default_max_fetch_prev_events() -> u16 { - 100_u16 -} +fn default_max_fetch_prev_events() -> u16 { 100_u16 } -fn default_trusted_servers() -> Vec { - vec![OwnedServerName::try_from("matrix.org").unwrap()] -} +fn default_trusted_servers() -> Vec { vec![OwnedServerName::try_from("matrix.org").unwrap()] } -fn default_log() -> String { - "warn,state_res=warn".to_owned() -} +fn default_log() -> String { "warn,state_res=warn".to_owned() } -fn default_notification_push_path() -> String { - "/_matrix/push/v1/notify".to_owned() -} +fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() } -fn default_turn_ttl() -> u64 { - 60 * 60 * 24 -} +fn default_turn_ttl() -> u64 { 60 * 60 * 24 } -fn default_presence_idle_timeout_s() -> u64 { - 5 * 60 -} +fn default_presence_idle_timeout_s() -> u64 { 5 * 60 } -fn default_presence_offline_timeout_s() -> 
u64 { - 30 * 60 -} +fn default_presence_offline_timeout_s() -> u64 { 30 * 60 } -fn default_rocksdb_log_level() -> String { - "warn".to_owned() -} +fn default_rocksdb_log_level() -> String { "warn".to_owned() } -fn default_rocksdb_log_time_to_roll() -> usize { - 0 -} +fn default_rocksdb_log_time_to_roll() -> usize { 0 } -fn default_rocksdb_parallelism_threads() -> usize { - num_cpus::get_physical() / 2 -} +fn default_rocksdb_parallelism_threads() -> usize { num_cpus::get_physical() / 2 } // I know, it's a great name -pub(crate) fn default_default_room_version() -> RoomVersionId { - RoomVersionId::V10 -} +pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 } fn default_rocksdb_max_log_file_size() -> usize { - // 4 megabytes - 4 * 1024 * 1024 + // 4 megabytes + 4 * 1024 * 1024 } fn default_ip_range_denylist() -> Vec { - vec![ - "127.0.0.0/8".to_owned(), - "10.0.0.0/8".to_owned(), - "172.16.0.0/12".to_owned(), - "192.168.0.0/16".to_owned(), - "100.64.0.0/10".to_owned(), - "192.0.0.0/24".to_owned(), - "169.254.0.0/16".to_owned(), - "192.88.99.0/24".to_owned(), - "198.18.0.0/15".to_owned(), - "192.0.2.0/24".to_owned(), - "198.51.100.0/24".to_owned(), - "203.0.113.0/24".to_owned(), - "224.0.0.0/4".to_owned(), - "::1/128".to_owned(), - "fe80::/10".to_owned(), - "fc00::/7".to_owned(), - "2001:db8::/32".to_owned(), - "ff00::/8".to_owned(), - "fec0::/10".to_owned(), - ] + vec![ + "127.0.0.0/8".to_owned(), + "10.0.0.0/8".to_owned(), + "172.16.0.0/12".to_owned(), + "192.168.0.0/16".to_owned(), + "100.64.0.0/10".to_owned(), + "192.0.0.0/24".to_owned(), + "169.254.0.0/16".to_owned(), + "192.88.99.0/24".to_owned(), + "198.18.0.0/15".to_owned(), + "192.0.2.0/24".to_owned(), + "198.51.100.0/24".to_owned(), + "203.0.113.0/24".to_owned(), + "224.0.0.0/4".to_owned(), + "::1/128".to_owned(), + "fe80::/10".to_owned(), + "fc00::/7".to_owned(), + "2001:db8::/32".to_owned(), + "ff00::/8".to_owned(), + "fec0::/10".to_owned(), + ] } fn 
default_url_preview_max_spider_size() -> usize { - 1_000_000 // 1MB + 1_000_000 // 1MB } -fn default_new_user_displayname_suffix() -> String { - "🏳️‍⚧️".to_owned() -} +fn default_new_user_displayname_suffix() -> String { "🏳️‍⚧️".to_owned() } diff --git a/src/config/proxy.rs b/src/config/proxy.rs index 80eee6af..54572648 100644 --- a/src/config/proxy.rs +++ b/src/config/proxy.rs @@ -24,119 +24,124 @@ use crate::Result; /// ## Include vs. Exclude /// If include is an empty list, it is assumed to be `["*"]`. /// -/// If a domain matches both the exclude and include list, the proxy will only be used if it was -/// included because of a more specific rule than it was excluded. In the above example, the proxy -/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. +/// If a domain matches both the exclude and include list, the proxy will only +/// be used if it was included because of a more specific rule than it was +/// excluded. In the above example, the proxy would be used for +/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`. 
#[derive(Clone, Default, Debug, Deserialize)] #[serde(rename_all = "snake_case")] pub enum ProxyConfig { - #[default] - None, - Global { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: Url, - }, - ByDomain(Vec), + #[default] + None, + Global { + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + }, + ByDomain(Vec), } impl ProxyConfig { - pub fn to_proxy(&self) -> Result> { - Ok(match self.clone() { - ProxyConfig::None => None, - ProxyConfig::Global { url } => Some(Proxy::all(url)?), - ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { - proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy - })), - }) - } + pub fn to_proxy(&self) -> Result> { + Ok(match self.clone() { + ProxyConfig::None => None, + ProxyConfig::Global { + url, + } => Some(Proxy::all(url)?), + ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| { + proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching + // proxy + })), + }) + } } #[derive(Clone, Debug, Deserialize)] pub struct PartialProxyConfig { - #[serde(deserialize_with = "crate::utils::deserialize_from_str")] - url: Url, - #[serde(default)] - include: Vec, - #[serde(default)] - exclude: Vec, + #[serde(deserialize_with = "crate::utils::deserialize_from_str")] + url: Url, + #[serde(default)] + include: Vec, + #[serde(default)] + exclude: Vec, } impl PartialProxyConfig { - pub fn for_url(&self, url: &Url) -> Option<&Url> { - let domain = url.domain()?; - let mut included_because = None; // most specific reason it was included - let mut excluded_because = None; // most specific reason it was excluded - if self.include.is_empty() { - // treat empty include list as `*` - included_because = Some(&WildCardedDomain::WildCard); - } - for wc_domain in &self.include { - if wc_domain.matches(domain) { - match included_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => included_because = 
Some(wc_domain), - } - } - } - for wc_domain in &self.exclude { - if wc_domain.matches(domain) { - match excluded_because { - Some(prev) if !wc_domain.more_specific_than(prev) => (), - _ => excluded_because = Some(wc_domain), - } - } - } - match (included_because, excluded_because) { - (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded - (Some(_), None) => Some(&self.url), - _ => None, - } - } + pub fn for_url(&self, url: &Url) -> Option<&Url> { + let domain = url.domain()?; + let mut included_because = None; // most specific reason it was included + let mut excluded_because = None; // most specific reason it was excluded + if self.include.is_empty() { + // treat empty include list as `*` + included_because = Some(&WildCardedDomain::WildCard); + } + for wc_domain in &self.include { + if wc_domain.matches(domain) { + match included_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => included_because = Some(wc_domain), + } + } + } + for wc_domain in &self.exclude { + if wc_domain.matches(domain) { + match excluded_because { + Some(prev) if !wc_domain.more_specific_than(prev) => (), + _ => excluded_because = Some(wc_domain), + } + } + } + match (included_because, excluded_because) { + (Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */ + // than excluded + (Some(_), None) => Some(&self.url), + _ => None, + } + } } /// A domain name, that optionally allows a * as its first subdomain. 
#[derive(Clone, Debug)] enum WildCardedDomain { - WildCard, - WildCarded(String), - Exact(String), + WildCard, + WildCarded(String), + Exact(String), } impl WildCardedDomain { - fn matches(&self, domain: &str) -> bool { - match self { - WildCardedDomain::WildCard => true, - WildCardedDomain::WildCarded(d) => domain.ends_with(d), - WildCardedDomain::Exact(d) => domain == d, - } - } - fn more_specific_than(&self, other: &Self) -> bool { - match (self, other) { - (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, - (_, WildCardedDomain::WildCard) => true, - (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), - (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => { - a != b && a.ends_with(b) - } - _ => false, - } - } + fn matches(&self, domain: &str) -> bool { + match self { + WildCardedDomain::WildCard => true, + WildCardedDomain::WildCarded(d) => domain.ends_with(d), + WildCardedDomain::Exact(d) => domain == d, + } + } + + fn more_specific_than(&self, other: &Self) -> bool { + match (self, other) { + (WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false, + (_, WildCardedDomain::WildCard) => true, + (WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a), + (WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b), + _ => false, + } + } } impl std::str::FromStr for WildCardedDomain { - type Err = std::convert::Infallible; - fn from_str(s: &str) -> Result { - // maybe do some domain validation? - Ok(if s.starts_with("*.") { - WildCardedDomain::WildCarded(s[1..].to_owned()) - } else if s == "*" { - WildCardedDomain::WildCarded("".to_owned()) - } else { - WildCardedDomain::Exact(s.to_owned()) - }) - } + type Err = std::convert::Infallible; + + fn from_str(s: &str) -> Result { + // maybe do some domain validation? 
+ Ok(if s.starts_with("*.") { + WildCardedDomain::WildCarded(s[1..].to_owned()) + } else if s == "*" { + WildCardedDomain::WildCarded("".to_owned()) + } else { + WildCardedDomain::Exact(s.to_owned()) + }) + } } impl<'de> Deserialize<'de> for WildCardedDomain { - fn deserialize(deserializer: D) -> Result - where - D: serde::de::Deserializer<'de>, - { - crate::utils::deserialize_from_str(deserializer) - } + fn deserialize(deserializer: D) -> Result + where + D: serde::de::Deserializer<'de>, + { + crate::utils::deserialize_from_str(deserializer) + } } diff --git a/src/database/abstraction.rs b/src/database/abstraction.rs index 44e24968..ebf4e4b9 100644 --- a/src/database/abstraction.rs +++ b/src/database/abstraction.rs @@ -1,8 +1,8 @@ +use std::{future::Future, pin::Pin, sync::Arc}; + use super::Config; use crate::Result; -use std::{future::Future, pin::Pin, sync::Arc}; - #[cfg(feature = "sqlite")] pub mod sqlite; @@ -13,53 +13,44 @@ pub(crate) mod rocksdb; pub(crate) mod watchers; pub(crate) trait KeyValueDatabaseEngine: Send + Sync { - fn open(config: &Config) -> Result - where - Self: Sized; - fn open_tree(&self, name: &'static str) -> Result>; - fn flush(&self) -> Result<()>; - fn cleanup(&self) -> Result<()> { - Ok(()) - } - fn memory_usage(&self) -> Result { - Ok("Current database engine does not support memory usage reporting.".to_owned()) - } + fn open(config: &Config) -> Result + where + Self: Sized; + fn open_tree(&self, name: &'static str) -> Result>; + fn flush(&self) -> Result<()>; + fn cleanup(&self) -> Result<()> { Ok(()) } + fn memory_usage(&self) -> Result { + Ok("Current database engine does not support memory usage reporting.".to_owned()) + } - #[allow(dead_code)] - fn clear_caches(&self) {} + #[allow(dead_code)] + fn clear_caches(&self) {} } pub(crate) trait KvTree: Send + Sync { - fn get(&self, key: &[u8]) -> Result>>; + fn get(&self, key: &[u8]) -> Result>>; - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; - fn insert_batch(&self, 
iter: &mut dyn Iterator, Vec)>) -> Result<()>; + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()>; + fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()>; - fn remove(&self, key: &[u8]) -> Result<()>; + fn remove(&self, key: &[u8]) -> Result<()>; - fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; + fn iter<'a>(&'a self) -> Box, Vec)> + 'a>; - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box, Vec)> + 'a>; + fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box, Vec)> + 'a>; - fn increment(&self, key: &[u8]) -> Result>; - fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; + fn increment(&self, key: &[u8]) -> Result>; + fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()>; - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Vec)> + 'a>; + fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box, Vec)> + 'a>; - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>>; - fn clear(&self) -> Result<()> { - for (key, _) in self.iter() { - self.remove(&key)?; - } + fn clear(&self) -> Result<()> { + for (key, _) in self.iter() { + self.remove(&key)?; + } - Ok(()) - } + Ok(()) + } } diff --git a/src/database/abstraction/rocksdb.rs b/src/database/abstraction/rocksdb.rs index 5321e147..373bd86d 100644 --- a/src/database/abstraction/rocksdb.rs +++ b/src/database/abstraction/rocksdb.rs @@ -1,293 +1,265 @@ use std::{ - future::Future, - pin::Pin, - sync::{Arc, RwLock}, + future::Future, + pin::Pin, + sync::{Arc, RwLock}, }; use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn}; use tracing::{debug, info}; +use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; use crate::{utils, Result}; -use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree}; - pub(crate) struct Engine { - rocks: rocksdb::DBWithThreadMode, - cache: rocksdb::Cache, - old_cfs: Vec, - config: Config, + 
rocks: rocksdb::DBWithThreadMode, + cache: rocksdb::Cache, + old_cfs: Vec, + config: Config, } struct RocksDbEngineTree<'a> { - db: Arc, - name: &'a str, - watchers: Watchers, - write_lock: RwLock<()>, + db: Arc, + name: &'a str, + watchers: Watchers, + write_lock: RwLock<()>, } fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Options { - // block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html# - let mut block_based_options = rocksdb::BlockBasedOptions::default(); + // block-based options: https://docs.rs/rocksdb/latest/rocksdb/struct.BlockBasedOptions.html# + let mut block_based_options = rocksdb::BlockBasedOptions::default(); - block_based_options.set_block_cache(rocksdb_cache); + block_based_options.set_block_cache(rocksdb_cache); - // "Difference of spinning disk" - // https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html - block_based_options.set_block_size(64 * 1024); - block_based_options.set_cache_index_and_filter_blocks(true); + // "Difference of spinning disk" + // https://zhangyuchi.gitbooks.io/rocksdbbook/content/RocksDB-Tuning-Guide.html + block_based_options.set_block_size(64 * 1024); + block_based_options.set_cache_index_and_filter_blocks(true); - // database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html# - let mut db_opts = rocksdb::Options::default(); + // database options: https://docs.rs/rocksdb/latest/rocksdb/struct.Options.html# + let mut db_opts = rocksdb::Options::default(); - let rocksdb_log_level = match config.rocksdb_log_level.as_ref() { - "debug" => Debug, - "info" => Info, - "error" => Error, - "fatal" => Fatal, - _ => Warn, - }; + let rocksdb_log_level = match config.rocksdb_log_level.as_ref() { + "debug" => Debug, + "info" => Info, + "error" => Error, + "fatal" => Fatal, + _ => Warn, + }; - let threads = if config.rocksdb_parallelism_threads == 0 { - num_cpus::get_physical() // max cores if user specified 0 - } else { - 
config.rocksdb_parallelism_threads - }; + let threads = if config.rocksdb_parallelism_threads == 0 { + num_cpus::get_physical() // max cores if user specified 0 + } else { + config.rocksdb_parallelism_threads + }; - db_opts.set_log_level(rocksdb_log_level); - db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size); - db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll); + db_opts.set_log_level(rocksdb_log_level); + db_opts.set_max_log_file_size(config.rocksdb_max_log_file_size); + db_opts.set_log_file_time_to_roll(config.rocksdb_log_time_to_roll); - if config.rocksdb_optimize_for_spinning_disks { - db_opts.set_skip_stats_update_on_db_open(true); - db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs - db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs - db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage - db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important - } else { - db_opts.set_skip_stats_update_on_db_open(false); - db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024); - db_opts.set_use_direct_reads(true); - db_opts.set_use_direct_io_for_flush_and_compaction(true); - db_opts.set_keep_log_file_num(20); - } + if config.rocksdb_optimize_for_spinning_disks { + db_opts.set_skip_stats_update_on_db_open(true); + db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs + db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs + db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage + db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for + // spinning hard drives. 
these are not really + // important + } else { + db_opts.set_skip_stats_update_on_db_open(false); + db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024); + db_opts.set_use_direct_reads(true); + db_opts.set_use_direct_io_for_flush_and_compaction(true); + db_opts.set_keep_log_file_num(20); + } - db_opts.set_block_based_table_factory(&block_based_options); - db_opts.set_level_compaction_dynamic_level_bytes(true); - db_opts.create_if_missing(true); - db_opts.increase_parallelism( - threads - .try_into() - .expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"), - ); - //db_opts.set_max_open_files(config.rocksdb_max_open_files); - db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd); - db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); - db_opts.optimize_level_style_compaction(10 * 1024 * 1024); + db_opts.set_block_based_table_factory(&block_based_options); + db_opts.set_level_compaction_dynamic_level_bytes(true); + db_opts.create_if_missing(true); + db_opts.increase_parallelism( + threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"), + ); + //db_opts.set_max_open_files(config.rocksdb_max_open_files); + db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd); + db_opts.set_compaction_style(rocksdb::DBCompactionStyle::Level); + db_opts.optimize_level_style_compaction(10 * 1024 * 1024); - // https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning - db_opts.set_max_background_jobs(6); - db_opts.set_bytes_per_sync(1_048_576); + // https://github.com/facebook/rocksdb/wiki/Setup-Options-and-Basic-Tuning + db_opts.set_max_background_jobs(6); + db_opts.set_bytes_per_sync(1_048_576); - // https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords - // - // Unclean shutdowns of a Matrix homeserver are likely to be fine when - // recovered in this manner as it's likely any lost information will be - // restored via federation. 
- db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords); + // https://github.com/facebook/rocksdb/wiki/WAL-Recovery-Modes#ktoleratecorruptedtailrecords + // + // Unclean shutdowns of a Matrix homeserver are likely to be fine when + // recovered in this manner as it's likely any lost information will be + // restored via federation. + db_opts.set_wal_recovery_mode(rocksdb::DBRecoveryMode::TolerateCorruptedTailRecords); - let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1); - db_opts.set_prefix_extractor(prefix_extractor); + let prefix_extractor = rocksdb::SliceTransform::create_fixed_prefix(1); + db_opts.set_prefix_extractor(prefix_extractor); - db_opts + db_opts } impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { - let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; - let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes); + fn open(config: &Config) -> Result { + let cache_capacity_bytes = (config.db_cache_capacity_mb * 1024.0 * 1024.0) as usize; + let rocksdb_cache = rocksdb::Cache::new_lru_cache(cache_capacity_bytes); - let db_opts = db_options(&rocksdb_cache, config); + let db_opts = db_options(&rocksdb_cache, config); - debug!("Listing column families in database"); - let cfs = rocksdb::DBWithThreadMode::::list_cf( - &db_opts, - &config.database_path, - ) - .unwrap_or_default(); + debug!("Listing column families in database"); + let cfs = rocksdb::DBWithThreadMode::::list_cf(&db_opts, &config.database_path) + .unwrap_or_default(); - debug!("Opening column family descriptors in database"); - info!("RocksDB database compaction will take place now, a delay in startup is expected"); - let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( - &db_opts, - &config.database_path, - cfs.iter().map(|name| { - rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config)) - }), - )?; + debug!("Opening column family descriptors 
in database"); + info!("RocksDB database compaction will take place now, a delay in startup is expected"); + let db = rocksdb::DBWithThreadMode::::open_cf_descriptors( + &db_opts, + &config.database_path, + cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))), + )?; - Ok(Arc::new(Engine { - rocks: db, - cache: rocksdb_cache, - old_cfs: cfs, - config: config.clone(), - })) - } + Ok(Arc::new(Engine { + rocks: db, + cache: rocksdb_cache, + old_cfs: cfs, + config: config.clone(), + })) + } - fn open_tree(&self, name: &'static str) -> Result> { - if !self.old_cfs.contains(&name.to_owned()) { - // Create if it didn't exist - debug!("Creating new column family in database: {}", name); - let _ = self - .rocks - .create_cf(name, &db_options(&self.cache, &self.config)); - } + fn open_tree(&self, name: &'static str) -> Result> { + if !self.old_cfs.contains(&name.to_owned()) { + // Create if it didn't exist + debug!("Creating new column family in database: {}", name); + let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config)); + } - Ok(Arc::new(RocksDbEngineTree { - name, - db: Arc::clone(self), - watchers: Watchers::default(), - write_lock: RwLock::new(()), - })) - } + Ok(Arc::new(RocksDbEngineTree { + name, + db: Arc::clone(self), + watchers: Watchers::default(), + write_lock: RwLock::new(()), + })) + } - fn flush(&self) -> Result<()> { - // TODO? - Ok(()) - } + fn flush(&self) -> Result<()> { + // TODO? 
+ Ok(()) + } - fn memory_usage(&self) -> Result { - let stats = - rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; - Ok(format!( - "Approximate memory usage of all the mem-tables: {:.3} MB\n\ - Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\ - Approximate memory usage of all the table readers: {:.3} MB\n\ - Approximate memory usage by cache: {:.3} MB\n\ - Approximate memory usage by cache pinned: {:.3} MB\n\ - ", - stats.mem_table_total as f64 / 1024.0 / 1024.0, - stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, - stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, - stats.cache_total as f64 / 1024.0 / 1024.0, - self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0, - )) - } + fn memory_usage(&self) -> Result { + let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?; + Ok(format!( + "Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \ + mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \ + usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n", + stats.mem_table_total as f64 / 1024.0 / 1024.0, + stats.mem_table_unflushed as f64 / 1024.0 / 1024.0, + stats.mem_table_readers_total as f64 / 1024.0 / 1024.0, + stats.cache_total as f64 / 1024.0 / 1024.0, + self.cache.get_pinned_usage() as f64 / 1024.0 / 1024.0, + )) + } - // TODO: figure out if this is needed for rocksdb - #[allow(dead_code)] - fn clear_caches(&self) {} + // TODO: figure out if this is needed for rocksdb + #[allow(dead_code)] + fn clear_caches(&self) {} } impl RocksDbEngineTree<'_> { - fn cf(&self) -> Arc> { - self.db.rocks.cf_handle(self.name).unwrap() - } + fn cf(&self) -> Arc> { self.db.rocks.cf_handle(self.name).unwrap() } } impl KvTree for RocksDbEngineTree<'_> { - fn get(&self, key: &[u8]) -> Result>> { - Ok(self.db.rocks.get_cf(&self.cf(), key)?) 
- } + fn get(&self, key: &[u8]) -> Result>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) } - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let lock = self.write_lock.read().unwrap(); - self.db.rocks.put_cf(&self.cf(), key, value)?; - drop(lock); + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let lock = self.write_lock.read().unwrap(); + self.db.rocks.put_cf(&self.cf(), key, value)?; + drop(lock); - self.watchers.wake(key); + self.watchers.wake(key); - Ok(()) - } + Ok(()) + } - fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { - for (key, value) in iter { - self.db.rocks.put_cf(&self.cf(), key, value)?; - } + fn insert_batch(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + for (key, value) in iter { + self.db.rocks.put_cf(&self.cf(), key, value)?; + } - Ok(()) - } + Ok(()) + } - fn remove(&self, key: &[u8]) -> Result<()> { - Ok(self.db.rocks.delete_cf(&self.cf(), key)?) - } + fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) 
} - fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { - Box::new( - self.db - .rocks - .iterator_cf(&self.cf(), rocksdb::IteratorMode::Start) - .map(std::result::Result::unwrap) - .map(|(k, v)| (Vec::from(k), Vec::from(v))), - ) - } + fn iter<'a>(&'a self) -> Box, Vec)> + 'a> { + Box::new( + self.db + .rocks + .iterator_cf(&self.cf(), rocksdb::IteratorMode::Start) + .map(std::result::Result::unwrap) + .map(|(k, v)| (Vec::from(k), Vec::from(v))), + ) + } - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box, Vec)> + 'a> { - Box::new( - self.db - .rocks - .iterator_cf( - &self.cf(), - rocksdb::IteratorMode::From( - from, - if backwards { - rocksdb::Direction::Reverse - } else { - rocksdb::Direction::Forward - }, - ), - ) - .map(std::result::Result::unwrap) - .map(|(k, v)| (Vec::from(k), Vec::from(v))), - ) - } + fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box, Vec)> + 'a> { + Box::new( + self.db + .rocks + .iterator_cf( + &self.cf(), + rocksdb::IteratorMode::From( + from, + if backwards { + rocksdb::Direction::Reverse + } else { + rocksdb::Direction::Forward + }, + ), + ) + .map(std::result::Result::unwrap) + .map(|(k, v)| (Vec::from(k), Vec::from(v))), + ) + } - fn increment(&self, key: &[u8]) -> Result> { - let lock = self.write_lock.write().unwrap(); + fn increment(&self, key: &[u8]) -> Result> { + let lock = self.write_lock.write().unwrap(); - let old = self.db.rocks.get_cf(&self.cf(), key)?; - let new = utils::increment(old.as_deref()).unwrap(); - self.db.rocks.put_cf(&self.cf(), key, &new)?; + let old = self.db.rocks.get_cf(&self.cf(), key)?; + let new = utils::increment(old.as_deref()).unwrap(); + self.db.rocks.put_cf(&self.cf(), key, &new)?; - drop(lock); - Ok(new) - } + drop(lock); + Ok(new) + } - fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { - let lock = self.write_lock.write().unwrap(); + fn increment_batch(&self, iter: &mut dyn Iterator>) -> Result<()> { + let lock = self.write_lock.write().unwrap(); 
- for key in iter { - let old = self.db.rocks.get_cf(&self.cf(), &key)?; - let new = utils::increment(old.as_deref()).unwrap(); - self.db.rocks.put_cf(&self.cf(), key, new)?; - } + for key in iter { + let old = self.db.rocks.get_cf(&self.cf(), &key)?; + let new = utils::increment(old.as_deref()).unwrap(); + self.db.rocks.put_cf(&self.cf(), key, new)?; + } - drop(lock); + drop(lock); - Ok(()) - } + Ok(()) + } - fn scan_prefix<'a>( - &'a self, - prefix: Vec, - ) -> Box, Vec)> + 'a> { - Box::new( - self.db - .rocks - .iterator_cf( - &self.cf(), - rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward), - ) - .map(std::result::Result::unwrap) - .map(|(k, v)| (Vec::from(k), Vec::from(v))) - .take_while(move |(k, _)| k.starts_with(&prefix)), - ) - } + fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box, Vec)> + 'a> { + Box::new( + self.db + .rocks + .iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward)) + .map(std::result::Result::unwrap) + .map(|(k, v)| (Vec::from(k), Vec::from(v))) + .take_while(move |(k, _)| k.starts_with(&prefix)), + ) + } - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - self.watchers.watch(prefix) - } + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + self.watchers.watch(prefix) + } } diff --git a/src/database/abstraction/sqlite.rs b/src/database/abstraction/sqlite.rs index 7a79ac3e..6325d190 100644 --- a/src/database/abstraction/sqlite.rs +++ b/src/database/abstraction/sqlite.rs @@ -1,340 +1,305 @@ -use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; -use crate::{database::Config, Result}; +use std::{ + cell::RefCell, + future::Future, + path::{Path, PathBuf}, + pin::Pin, + sync::Arc, +}; + use parking_lot::{Mutex, MutexGuard}; use rusqlite::{Connection, DatabaseName::Main, OptionalExtension}; -use std::{ - cell::RefCell, - future::Future, - path::{Path, PathBuf}, - pin::Pin, - sync::Arc, -}; use thread_local::ThreadLocal; use tracing::debug; +use 
super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree}; +use crate::{database::Config, Result}; + thread_local! { - static READ_CONNECTION: RefCell> = RefCell::new(None); - static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); + static READ_CONNECTION: RefCell> = RefCell::new(None); + static READ_CONNECTION_ITERATOR: RefCell> = RefCell::new(None); } struct PreparedStatementIterator<'a> { - pub iterator: Box + 'a>, - pub _statement_ref: NonAliasingBox>, + pub iterator: Box + 'a>, + pub _statement_ref: NonAliasingBox>, } impl Iterator for PreparedStatementIterator<'_> { - type Item = TupleOfBytes; + type Item = TupleOfBytes; - fn next(&mut self) -> Option { - self.iterator.next() - } + fn next(&mut self) -> Option { self.iterator.next() } } struct NonAliasingBox(*mut T); impl Drop for NonAliasingBox { - fn drop(&mut self) { - unsafe { - let _ = Box::from_raw(self.0); - }; - } + fn drop(&mut self) { + unsafe { + let _ = Box::from_raw(self.0); + }; + } } pub struct Engine { - writer: Mutex, - read_conn_tls: ThreadLocal, - read_iterator_conn_tls: ThreadLocal, + writer: Mutex, + read_conn_tls: ThreadLocal, + read_iterator_conn_tls: ThreadLocal, - path: PathBuf, - cache_size_per_thread: u32, + path: PathBuf, + cache_size_per_thread: u32, } impl Engine { - fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { - let conn = Connection::open(path)?; + fn prepare_conn(path: &Path, cache_size_kb: u32) -> Result { + let conn = Connection::open(path)?; - conn.pragma_update(Some(Main), "page_size", 2048)?; - conn.pragma_update(Some(Main), "journal_mode", "WAL")?; - conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; - conn.pragma_update(Some(Main), "cache_size", -i64::from(cache_size_kb))?; - conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; + conn.pragma_update(Some(Main), "page_size", 2048)?; + conn.pragma_update(Some(Main), "journal_mode", "WAL")?; + conn.pragma_update(Some(Main), "synchronous", "NORMAL")?; + conn.pragma_update(Some(Main), 
"cache_size", -i64::from(cache_size_kb))?; + conn.pragma_update(Some(Main), "wal_autocheckpoint", 0)?; - Ok(conn) - } + Ok(conn) + } - fn write_lock(&self) -> MutexGuard<'_, Connection> { - self.writer.lock() - } + fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() } - fn read_lock(&self) -> &Connection { - self.read_conn_tls - .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) - } + fn read_lock(&self) -> &Connection { + self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + } - fn read_lock_iterator(&self) -> &Connection { - self.read_iterator_conn_tls - .get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) - } + fn read_lock_iterator(&self) -> &Connection { + self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap()) + } - pub fn flush_wal(self: &Arc) -> Result<()> { - self.write_lock() - .pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; - Ok(()) - } + pub fn flush_wal(self: &Arc) -> Result<()> { + self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?; + Ok(()) + } } impl KeyValueDatabaseEngine for Arc { - fn open(config: &Config) -> Result { - let path = Path::new(&config.database_path).join("conduit.db"); + fn open(config: &Config) -> Result { + let path = Path::new(&config.database_path).join("conduit.db"); - // calculates cache-size per permanent connection - // 1. convert MB to KiB - // 2. divide by permanent connections + permanent iter connections + write connection - // 3. round down to nearest integer - let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0) - / ((num_cpus::get().max(1) * 2) + 1) as f64) - as u32; + // calculates cache-size per permanent connection + // 1. convert MB to KiB + // 2. divide by permanent connections + permanent iter connections + write + // connection + // 3. 
round down to nearest integer + let cache_size_per_thread: u32 = + ((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32; - let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); + let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?); - let arc = Arc::new(Engine { - writer, - read_conn_tls: ThreadLocal::new(), - read_iterator_conn_tls: ThreadLocal::new(), - path, - cache_size_per_thread, - }); + let arc = Arc::new(Engine { + writer, + read_conn_tls: ThreadLocal::new(), + read_iterator_conn_tls: ThreadLocal::new(), + path, + cache_size_per_thread, + }); - Ok(arc) - } + Ok(arc) + } - fn open_tree(&self, name: &str) -> Result> { - self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?; + fn open_tree(&self, name: &str) -> Result> { + self.write_lock().execute( + &format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), + [], + )?; - Ok(Arc::new(SqliteTable { - engine: Arc::clone(self), - name: name.to_owned(), - watchers: Watchers::default(), - })) - } + Ok(Arc::new(SqliteTable { + engine: Arc::clone(self), + name: name.to_owned(), + watchers: Watchers::default(), + })) + } - fn flush(&self) -> Result<()> { - // we enabled PRAGMA synchronous=normal, so this should not be necessary - Ok(()) - } + fn flush(&self) -> Result<()> { + // we enabled PRAGMA synchronous=normal, so this should not be necessary + Ok(()) + } - fn cleanup(&self) -> Result<()> { - self.flush_wal() - } + fn cleanup(&self) -> Result<()> { self.flush_wal() } } pub struct SqliteTable { - engine: Arc, - name: String, - watchers: Watchers, + engine: Arc, + name: String, + watchers: Watchers, } type TupleOfBytes = (Vec, Vec); impl SqliteTable { - fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { - Ok(guard - .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? 
- .query_row([key], |row| row.get(0)) - .optional()?) - } + fn get_with_guard(&self, guard: &Connection, key: &[u8]) -> Result>> { + Ok(guard + .prepare(format!("SELECT value FROM {} WHERE key = ?", self.name).as_str())? + .query_row([key], |row| row.get(0)) + .optional()?) + } - fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { - guard.execute( - format!( - "INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", - self.name - ) - .as_str(), - [key, value], - )?; - Ok(()) - } + fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> { + guard.execute( + format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(), + [key, value], + )?; + Ok(()) + } - pub fn iter_with_guard<'a>( - &'a self, - guard: &'a Connection, - ) -> Box + 'a> { - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); + pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box + 'a> { + let statement = Box::leak(Box::new( + guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(), + )); - let statement_ref = NonAliasingBox(statement); + let statement_ref = NonAliasingBox(statement); - //let name = self.name.clone(); + //let name = self.name.clone(); - let iterator = Box::new( - statement - .query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); + let iterator = Box::new( + statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()), + ); - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) - } + Box::new(PreparedStatementIterator { + iterator, + _statement_ref: statement_ref, + }) + } } impl KvTree for SqliteTable { - fn get(&self, key: &[u8]) -> Result>> { - self.get_with_guard(self.engine.read_lock(), key) - } + fn get(&self, 
key: &[u8]) -> Result>> { self.get_with_guard(self.engine.read_lock(), key) } - fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { - let guard = self.engine.write_lock(); - self.insert_with_guard(&guard, key, value)?; - drop(guard); - self.watchers.wake(key); - Ok(()) - } + fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> { + let guard = self.engine.write_lock(); + self.insert_with_guard(&guard, key, value)?; + drop(guard); + self.watchers.wake(key); + Ok(()) + } - fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { - let guard = self.engine.write_lock(); + fn insert_batch<'a>(&self, iter: &mut dyn Iterator, Vec)>) -> Result<()> { + let guard = self.engine.write_lock(); - guard.execute("BEGIN", [])?; - for (key, value) in iter { - self.insert_with_guard(&guard, &key, &value)?; - } - guard.execute("COMMIT", [])?; + guard.execute("BEGIN", [])?; + for (key, value) in iter { + self.insert_with_guard(&guard, &key, &value)?; + } + guard.execute("COMMIT", [])?; - drop(guard); + drop(guard); - Ok(()) - } + Ok(()) + } - fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()> { - let guard = self.engine.write_lock(); + fn increment_batch<'a>(&self, iter: &mut dyn Iterator>) -> Result<()> { + let guard = self.engine.write_lock(); - guard.execute("BEGIN", [])?; - for key in iter { - let old = self.get_with_guard(&guard, &key)?; - let new = crate::utils::increment(old.as_deref()) - .expect("utils::increment always returns Some"); - self.insert_with_guard(&guard, &key, &new)?; - } - guard.execute("COMMIT", [])?; + guard.execute("BEGIN", [])?; + for key in iter { + let old = self.get_with_guard(&guard, &key)?; + let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); + self.insert_with_guard(&guard, &key, &new)?; + } + guard.execute("COMMIT", [])?; - drop(guard); + drop(guard); - Ok(()) - } + Ok(()) + } - fn remove(&self, key: &[u8]) -> Result<()> { - let guard = 
self.engine.write_lock(); + fn remove(&self, key: &[u8]) -> Result<()> { + let guard = self.engine.write_lock(); - guard.execute( - format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), - [key], - )?; + guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?; - Ok(()) - } + Ok(()) + } - fn iter<'a>(&'a self) -> Box + 'a> { - let guard = self.engine.read_lock_iterator(); + fn iter<'a>(&'a self) -> Box + 'a> { + let guard = self.engine.read_lock_iterator(); - self.iter_with_guard(guard) - } + self.iter_with_guard(guard) + } - fn iter_from<'a>( - &'a self, - from: &[u8], - backwards: bool, - ) -> Box + 'a> { - let guard = self.engine.read_lock_iterator(); - let from = from.to_vec(); // TODO change interface? + fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box + 'a> { + let guard = self.engine.read_lock_iterator(); + let from = from.to_vec(); // TODO change interface? - //let name = self.name.clone(); + //let name = self.name.clone(); - if backwards { - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", - &self.name - )) - .unwrap(), - )); + if backwards { + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key <= ? ORDER BY key DESC", + &self.name + )) + .unwrap(), + )); - let statement_ref = NonAliasingBox(statement); + let statement_ref = NonAliasingBox(statement); - let iterator = Box::new( - statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) - } else { - let statement = Box::leak(Box::new( - guard - .prepare(&format!( - "SELECT key, value FROM {} WHERE key >= ? 
ORDER BY key ASC", - &self.name - )) - .unwrap(), - )); + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(move |r| r.unwrap()), + ); + Box::new(PreparedStatementIterator { + iterator, + _statement_ref: statement_ref, + }) + } else { + let statement = Box::leak(Box::new( + guard + .prepare(&format!( + "SELECT key, value FROM {} WHERE key >= ? ORDER BY key ASC", + &self.name + )) + .unwrap(), + )); - let statement_ref = NonAliasingBox(statement); + let statement_ref = NonAliasingBox(statement); - let iterator = Box::new( - statement - .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) - .unwrap() - .map(move |r| r.unwrap()), - ); + let iterator = Box::new( + statement + .query_map([from], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))) + .unwrap() + .map(move |r| r.unwrap()), + ); - Box::new(PreparedStatementIterator { - iterator, - _statement_ref: statement_ref, - }) - } - } + Box::new(PreparedStatementIterator { + iterator, + _statement_ref: statement_ref, + }) + } + } - fn increment(&self, key: &[u8]) -> Result> { - let guard = self.engine.write_lock(); + fn increment(&self, key: &[u8]) -> Result> { + let guard = self.engine.write_lock(); - let old = self.get_with_guard(&guard, key)?; + let old = self.get_with_guard(&guard, key)?; - let new = - crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); + let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some"); - self.insert_with_guard(&guard, key, &new)?; + self.insert_with_guard(&guard, key, &new)?; - Ok(new) - } + Ok(new) + } - fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { - Box::new( - self.iter_from(&prefix, false) - .take_while(move |(key, _)| key.starts_with(&prefix)), - ) - } + fn scan_prefix<'a>(&'a self, prefix: Vec) -> Box + 'a> { + Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix))) + 
} - fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { - self.watchers.watch(prefix) - } + fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + self.watchers.watch(prefix) + } - fn clear(&self) -> Result<()> { - debug!("clear: running"); - self.engine - .write_lock() - .execute(format!("DELETE FROM {}", self.name).as_str(), [])?; - debug!("clear: ran"); - Ok(()) - } + fn clear(&self) -> Result<()> { + debug!("clear: running"); + self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?; + debug!("clear: ran"); + Ok(()) + } } diff --git a/src/database/abstraction/watchers.rs b/src/database/abstraction/watchers.rs index 07087c1f..bbb306b7 100644 --- a/src/database/abstraction/watchers.rs +++ b/src/database/abstraction/watchers.rs @@ -1,56 +1,55 @@ use std::{ - collections::{hash_map, HashMap}, - future::Future, - pin::Pin, - sync::RwLock, + collections::{hash_map, HashMap}, + future::Future, + pin::Pin, + sync::RwLock, }; + use tokio::sync::watch; type Watcher = RwLock, (watch::Sender<()>, watch::Receiver<()>)>>; #[derive(Default)] pub(super) struct Watchers { - watchers: Watcher, + watchers: Watcher, } impl Watchers { - pub(super) fn watch<'a>( - &'a self, - prefix: &[u8], - ) -> Pin + Send + 'a>> { - let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { - hash_map::Entry::Occupied(o) => o.get().1.clone(), - hash_map::Entry::Vacant(v) => { - let (tx, rx) = tokio::sync::watch::channel(()); - v.insert((tx, rx.clone())); - rx - } - }; + pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin + Send + 'a>> { + let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) { + hash_map::Entry::Occupied(o) => o.get().1.clone(), + hash_map::Entry::Vacant(v) => { + let (tx, rx) = tokio::sync::watch::channel(()); + v.insert((tx, rx.clone())); + rx + }, + }; - Box::pin(async move { - // Tx is never destroyed - rx.changed().await.unwrap(); - }) - } - pub(super) fn wake(&self, key: 
&[u8]) { - let watchers = self.watchers.read().unwrap(); - let mut triggered = Vec::new(); + Box::pin(async move { + // Tx is never destroyed + rx.changed().await.unwrap(); + }) + } - for length in 0..=key.len() { - if watchers.contains_key(&key[..length]) { - triggered.push(&key[..length]); - } - } + pub(super) fn wake(&self, key: &[u8]) { + let watchers = self.watchers.read().unwrap(); + let mut triggered = Vec::new(); - drop(watchers); + for length in 0..=key.len() { + if watchers.contains_key(&key[..length]) { + triggered.push(&key[..length]); + } + } - if !triggered.is_empty() { - let mut watchers = self.watchers.write().unwrap(); - for prefix in triggered { - if let Some(tx) = watchers.remove(prefix) { - let _ = tx.0.send(()); - } - } - }; - } + drop(watchers); + + if !triggered.is_empty() { + let mut watchers = self.watchers.write().unwrap(); + for prefix in triggered { + if let Some(tx) = watchers.remove(prefix) { + let _ = tx.0.send(()); + } + } + }; + } } diff --git a/src/database/key_value/account_data.rs b/src/database/key_value/account_data.rs index 4e9da595..90b033fb 100644 --- a/src/database/key_value/account_data.rs +++ b/src/database/key_value/account_data.rs @@ -1,148 +1,120 @@ use std::collections::HashMap; use ruma::{ - api::client::error::ErrorKind, - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, + api::client::error::ErrorKind, + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + RoomId, UserId, }; use tracing::warn; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::account_data::Data for KeyValueDatabase { - /// Places one event in the account data of the user and removes the previous entry. 
- #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - fn update( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - let mut prefix = room_id - .map(std::string::ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xff); + /// Places one event in the account data of the user and removes the + /// previous entry. + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()> { + let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); - let mut roomuserdataid = prefix.clone(); - roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - roomuserdataid.push(0xff); - roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); + let mut roomuserdataid = prefix.clone(); + roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + roomuserdataid.push(0xFF); + roomuserdataid.extend_from_slice(event_type.to_string().as_bytes()); - let mut key = prefix; - key.extend_from_slice(event_type.to_string().as_bytes()); + let mut key = prefix; + key.extend_from_slice(event_type.to_string().as_bytes()); - if data.get("type").is_none() || data.get("content").is_none() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Account data doesn't have all required fields.", - )); - } + if data.get("type").is_none() || data.get("content").is_none() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Account data doesn't have all required fields.", + )); + } - self.roomuserdataid_accountdata.insert( - 
&roomuserdataid, - &serde_json::to_vec(&data).expect("to_vec always works on json values"), - )?; + self.roomuserdataid_accountdata.insert( + &roomuserdataid, + &serde_json::to_vec(&data).expect("to_vec always works on json values"), + )?; - let prev = self.roomusertype_roomuserdataid.get(&key)?; + let prev = self.roomusertype_roomuserdataid.get(&key)?; - self.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; + self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; - // Remove old entry - if let Some(prev) = prev { - self.roomuserdataid_accountdata.remove(&prev)?; - } + // Remove old entry + if let Some(prev) = prev { + self.roomuserdataid_accountdata.remove(&prev)?; + } - Ok(()) - } + Ok(()) + } - /// Searches the account data for a specific kind. - #[tracing::instrument(skip(self, room_id, user_id, kind))] - fn get( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - kind: RoomAccountDataEventType, - ) -> Result>> { - let mut key = room_id - .map(std::string::ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(kind.to_string().as_bytes()); + /// Searches the account data for a specific kind. + #[tracing::instrument(skip(self, room_id, user_id, kind))] + fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, + ) -> Result>> { + let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(kind.to_string().as_bytes()); - self.roomusertype_roomuserdataid - .get(&key)? - .and_then(|roomuserdataid| { - self.roomuserdataid_accountdata - .get(&roomuserdataid) - .transpose() - }) - .transpose()? 
- .map(|data| { - serde_json::from_slice(&data) - .map_err(|_| Error::bad_database("could not deserialize")) - }) - .transpose() - } + self.roomusertype_roomuserdataid + .get(&key)? + .and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose()) + .transpose()? + .map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize"))) + .transpose() + } - /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip(self, room_id, user_id, since))] - fn changes_since( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - since: u64, - ) -> Result>> { - let mut userdata = HashMap::new(); + /// Returns all changes to the account data that happened after `since`. + #[tracing::instrument(skip(self, room_id, user_id, since))] + fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, + ) -> Result>> { + let mut userdata = HashMap::new(); - let mut prefix = room_id - .map(std::string::ToString::to_string) - .unwrap_or_default() - .as_bytes() - .to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(user_id.as_bytes()); - prefix.push(0xff); + let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(user_id.as_bytes()); + prefix.push(0xFF); - // Skip the data that's exactly at since, because we sent that last time - let mut first_possible = prefix.clone(); - first_possible.extend_from_slice(&(since + 1).to_be_bytes()); + // Skip the data that's exactly at since, because we sent that last time + let mut first_possible = prefix.clone(); + first_possible.extend_from_slice(&(since + 1).to_be_bytes()); - for r in self - .roomuserdataid_accountdata - .iter_from(&first_possible, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(k, v)| { - Ok::<_, Error>(( - RoomAccountDataEventType::from( - utils::string_from_bytes(k.rsplit(|&b| b == 
0xff).next().ok_or_else( - || Error::bad_database("RoomUserData ID in db is invalid."), - )?) - .map_err(|e| { - warn!("RoomUserData ID in database is invalid: {}", e); - Error::bad_database("RoomUserData ID in db is invalid.") - })?, - ), - serde_json::from_slice::>(&v).map_err(|_| { - Error::bad_database("Database contains invalid account data.") - })?, - )) - }) - { - let (kind, data) = r?; - userdata.insert(kind, data); - } + for r in self + .roomuserdataid_accountdata + .iter_from(&first_possible, false) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(k, v)| { + Ok::<_, Error>(( + RoomAccountDataEventType::from( + utils::string_from_bytes( + k.rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?, + ) + .map_err(|e| { + warn!("RoomUserData ID in database is invalid: {}", e); + Error::bad_database("RoomUserData ID in db is invalid.") + })?, + ), + serde_json::from_slice::>(&v) + .map_err(|_| Error::bad_database("Database contains invalid account data."))?, + )) + }) { + let (kind, data) = r?; + userdata.insert(kind, data); + } - Ok(userdata) - } + Ok(userdata) + } } diff --git a/src/database/key_value/appservice.rs b/src/database/key_value/appservice.rs index 55206243..598fb9cd 100644 --- a/src/database/key_value/appservice.rs +++ b/src/database/key_value/appservice.rs @@ -3,78 +3,58 @@ use ruma::api::appservice::Registration; use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::appservice::Data for KeyValueDatabase { - /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: Registration) -> Result { - let id = yaml.id.as_str(); - self.id_appserviceregistrations.insert( - id.as_bytes(), - serde_yaml::to_string(&yaml).unwrap().as_bytes(), - )?; - self.cached_registrations - .write() - .unwrap() - .insert(id.to_owned(), yaml.clone()); + /// Registers an appservice and returns the ID to the caller + fn 
register_appservice(&self, yaml: Registration) -> Result { + let id = yaml.id.as_str(); + self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?; + self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone()); - Ok(id.to_owned()) - } + Ok(id.to_owned()) + } - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.id_appserviceregistrations - .remove(service_name.as_bytes())?; - self.cached_registrations - .write() - .unwrap() - .remove(service_name); - Ok(()) - } + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.id_appserviceregistrations.remove(service_name.as_bytes())?; + self.cached_registrations.write().unwrap().remove(service_name); + Ok(()) + } - fn get_registration(&self, id: &str) -> Result> { - self.cached_registrations - .read() - .unwrap() - .get(id) - .map_or_else( - || { - self.id_appserviceregistrations - .get(id.as_bytes())? - .map(|bytes| { - serde_yaml::from_slice(&bytes).map_err(|_| { - Error::bad_database( - "Invalid registration bytes in id_appserviceregistrations.", - ) - }) - }) - .transpose() - }, - |r| Ok(Some(r.clone())), - ) - } + fn get_registration(&self, id: &str) -> Result> { + self.cached_registrations.read().unwrap().get(id).map_or_else( + || { + self.id_appserviceregistrations + .get(id.as_bytes())? 
+ .map(|bytes| { + serde_yaml::from_slice(&bytes).map_err(|_| { + Error::bad_database("Invalid registration bytes in id_appserviceregistrations.") + }) + }) + .transpose() + }, + |r| Ok(Some(r.clone())), + ) + } - fn iter_ids<'a>(&'a self) -> Result> + 'a>> { - Ok(Box::new(self.id_appserviceregistrations.iter().map( - |(id, _)| { - utils::string_from_bytes(&id).map_err(|_| { - Error::bad_database("Invalid id bytes in id_appserviceregistrations.") - }) - }, - ))) - } + fn iter_ids<'a>(&'a self) -> Result> + 'a>> { + Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| { + utils::string_from_bytes(&id) + .map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations.")) + }))) + } - fn all(&self) -> Result> { - self.iter_ids()? - .filter_map(std::result::Result::ok) - .map(move |id| { - Ok(( - id.clone(), - self.get_registration(&id)? - .expect("iter_ids only returns appservices that exist"), - )) - }) - .collect() - } + fn all(&self) -> Result> { + self.iter_ids()? 
+ .filter_map(std::result::Result::ok) + .map(move |id| { + Ok(( + id.clone(), + self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"), + )) + }) + .collect() + } } diff --git a/src/database/key_value/globals.rs b/src/database/key_value/globals.rs index b921d822..e24c6047 100644 --- a/src/database/key_value/globals.rs +++ b/src/database/key_value/globals.rs @@ -4,9 +4,9 @@ use async_trait::async_trait; use futures_util::{stream::FuturesUnordered, StreamExt}; use lru_cache::LruCache; use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - signatures::Ed25519KeyPair, - DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + signatures::Ed25519KeyPair, + DeviceId, MilliSecondsSinceUnixEpoch, OwnedServerSigningKeyId, ServerName, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; @@ -16,139 +16,118 @@ const LAST_CHECK_FOR_UPDATES_COUNT: &[u8] = b"u"; #[async_trait] impl service::globals::Data for KeyValueDatabase { - fn next_count(&self) -> Result { - utils::u64_from_bytes(&self.global.increment(COUNTER)?) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) - } + fn next_count(&self) -> Result { + utils::u64_from_bytes(&self.global.increment(COUNTER)?) + .map_err(|_| Error::bad_database("Count has invalid bytes.")) + } - fn current_count(&self) -> Result { - self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Count has invalid bytes.")) - }) - } + fn current_count(&self) -> Result { + self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes.")) + }) + } - fn last_check_for_updates_id(&self) -> Result { - self.global - .get(LAST_CHECK_FOR_UPDATES_COUNT)? 
- .map_or(Ok(0_u64), |bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("last check for updates count has invalid bytes.") - }) - }) - } + fn last_check_for_updates_id(&self) -> Result { + self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("last check for updates count has invalid bytes.")) + }) + } - fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.global - .insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; + fn update_check_for_updates_id(&self, id: u64) -> Result<()> { + self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?; - Ok(()) - } + Ok(()) + } - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let userid_bytes = user_id.as_bytes().to_vec(); - let mut userid_prefix = userid_bytes.clone(); - userid_prefix.push(0xff); + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let userid_bytes = user_id.as_bytes().to_vec(); + let mut userid_prefix = userid_bytes.clone(); + userid_prefix.push(0xFF); - let mut userdeviceid_prefix = userid_prefix.clone(); - userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); - userdeviceid_prefix.push(0xff); + let mut userdeviceid_prefix = userid_prefix.clone(); + userdeviceid_prefix.extend_from_slice(device_id.as_bytes()); + userdeviceid_prefix.push(0xFF); - let mut futures = FuturesUnordered::new(); + let mut futures = FuturesUnordered::new(); - // Return when *any* user changed his key - // TODO: only send for user they share a room with - futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); + // Return when *any* user changed his key + // TODO: only send for user they share a room with + futures.push(self.todeviceid_events.watch_prefix(&userdeviceid_prefix)); - futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); - 
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); - futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); - futures.push( - self.userroomid_notificationcount - .watch_prefix(&userid_prefix), - ); - futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_joined.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix)); + futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix)); - // Events for rooms we are in - for room_id in services() - .rooms - .state_cache - .rooms_joined(user_id) - .filter_map(std::result::Result::ok) - { - let short_roomid = services() - .rooms - .short - .get_shortroomid(&room_id) - .ok() - .flatten() - .expect("room exists") - .to_be_bytes() - .to_vec(); + // Events for rooms we are in + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) { + let short_roomid = services() + .rooms + .short + .get_shortroomid(&room_id) + .ok() + .flatten() + .expect("room exists") + .to_be_bytes() + .to_vec(); - let roomid_bytes = room_id.as_bytes().to_vec(); - let mut roomid_prefix = roomid_bytes.clone(); - roomid_prefix.push(0xff); + let roomid_bytes = room_id.as_bytes().to_vec(); + let mut roomid_prefix = roomid_bytes.clone(); + roomid_prefix.push(0xFF); - // PDUs - futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); + // PDUs + futures.push(self.pduid_pdu.watch_prefix(&short_roomid)); - // EDUs - futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes)); + // EDUs + futures.push(self.roomid_lasttypingupdate.watch_prefix(&roomid_bytes)); - futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); + futures.push(self.readreceiptid_readreceipt.watch_prefix(&roomid_prefix)); - 
// Key changes - futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); + // Key changes + futures.push(self.keychangeid_userid.watch_prefix(&roomid_prefix)); - // Room account data - let mut roomuser_prefix = roomid_prefix.clone(); - roomuser_prefix.extend_from_slice(&userid_prefix); + // Room account data + let mut roomuser_prefix = roomid_prefix.clone(); + roomuser_prefix.extend_from_slice(&userid_prefix); - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&roomuser_prefix), - ); - } + futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix)); + } - let mut globaluserdata_prefix = vec![0xff]; - globaluserdata_prefix.extend_from_slice(&userid_prefix); + let mut globaluserdata_prefix = vec![0xFF]; + globaluserdata_prefix.extend_from_slice(&userid_prefix); - futures.push( - self.roomusertype_roomuserdataid - .watch_prefix(&globaluserdata_prefix), - ); + futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix)); - // More key changes (used when user is not joined to any rooms) - futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); + // More key changes (used when user is not joined to any rooms) + futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix)); - // One time keys - futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); + // One time keys + futures.push(self.userid_lastonetimekeyupdate.watch_prefix(&userid_bytes)); - futures.push(Box::pin(services().globals.rotate.watch())); + futures.push(Box::pin(services().globals.rotate.watch())); - // Wait until one of them finds something - futures.next().await; + // Wait until one of them finds something + futures.next().await; - Ok(()) - } + Ok(()) + } - fn cleanup(&self) -> Result<()> { - self.db.cleanup() - } + fn cleanup(&self) -> Result<()> { self.db.cleanup() } - fn memory_usage(&self) -> String { - let pdu_cache = self.pdu_cache.lock().unwrap().len(); - let shorteventid_cache = 
self.shorteventid_cache.lock().unwrap().len(); - let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); - let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len(); - let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len(); - let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); - let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); - let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); + fn memory_usage(&self) -> String { + let pdu_cache = self.pdu_cache.lock().unwrap().len(); + let shorteventid_cache = self.shorteventid_cache.lock().unwrap().len(); + let auth_chain_cache = self.auth_chain_cache.lock().unwrap().len(); + let eventidshort_cache = self.eventidshort_cache.lock().unwrap().len(); + let statekeyshort_cache = self.statekeyshort_cache.lock().unwrap().len(); + let our_real_users_cache = self.our_real_users_cache.read().unwrap().len(); + let appservice_in_room_cache = self.appservice_in_room_cache.read().unwrap().len(); + let lasttimelinecount_cache = self.lasttimelinecount_cache.lock().unwrap().len(); - let mut response = format!( - "\ + let mut response = format!( + "\ pdu_cache: {pdu_cache} shorteventid_cache: {shorteventid_cache} auth_chain_cache: {auth_chain_cache} @@ -157,155 +136,137 @@ statekeyshort_cache: {statekeyshort_cache} our_real_users_cache: {our_real_users_cache} appservice_in_room_cache: {appservice_in_room_cache} lasttimelinecount_cache: {lasttimelinecount_cache}\n" - ); - if let Ok(db_stats) = self.db.memory_usage() { - response += &db_stats; - } + ); + if let Ok(db_stats) = self.db.memory_usage() { + response += &db_stats; + } - response - } + response + } - fn clear_caches(&self, amount: u32) { - if amount > 0 { - let c = &mut *self.pdu_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - if amount > 1 { - let c = &mut *self.shorteventid_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - 
if amount > 2 { - let c = &mut *self.auth_chain_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - if amount > 3 { - let c = &mut *self.eventidshort_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - if amount > 4 { - let c = &mut *self.statekeyshort_cache.lock().unwrap(); - *c = LruCache::new(c.capacity()); - } - if amount > 5 { - let c = &mut *self.our_real_users_cache.write().unwrap(); - *c = HashMap::new(); - } - if amount > 6 { - let c = &mut *self.appservice_in_room_cache.write().unwrap(); - *c = HashMap::new(); - } - if amount > 7 { - let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); - *c = HashMap::new(); - } - } + fn clear_caches(&self, amount: u32) { + if amount > 0 { + let c = &mut *self.pdu_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 1 { + let c = &mut *self.shorteventid_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 2 { + let c = &mut *self.auth_chain_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 3 { + let c = &mut *self.eventidshort_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 4 { + let c = &mut *self.statekeyshort_cache.lock().unwrap(); + *c = LruCache::new(c.capacity()); + } + if amount > 5 { + let c = &mut *self.our_real_users_cache.write().unwrap(); + *c = HashMap::new(); + } + if amount > 6 { + let c = &mut *self.appservice_in_room_cache.write().unwrap(); + *c = HashMap::new(); + } + if amount > 7 { + let c = &mut *self.lasttimelinecount_cache.lock().unwrap(); + *c = HashMap::new(); + } + } - fn load_keypair(&self) -> Result { - let keypair_bytes = self.global.get(b"keypair")?.map_or_else( - || { - let keypair = utils::generate_keypair(); - self.global.insert(b"keypair", &keypair)?; - Ok::<_, Error>(keypair) - }, - Ok, - )?; + fn load_keypair(&self) -> Result { + let keypair_bytes = self.global.get(b"keypair")?.map_or_else( + || { + let keypair = utils::generate_keypair(); + 
self.global.insert(b"keypair", &keypair)?; + Ok::<_, Error>(keypair) + }, + Ok, + )?; - let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff); + let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF); - utils::string_from_bytes( - // 1. version - parts - .next() - .expect("splitn always returns at least one element"), - ) - .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) - .and_then(|version| { - // 2. key - parts - .next() - .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) - .map(|key| (version, key)) - }) - .and_then(|(version, key)| { - Ed25519KeyPair::from_der(key, version) - .map_err(|_| Error::bad_database("Private or public keys are invalid.")) - }) - } - fn remove_keypair(&self) -> Result<()> { - self.global.remove(b"keypair") - } + utils::string_from_bytes( + // 1. version + parts.next().expect("splitn always returns at least one element"), + ) + .map_err(|_| Error::bad_database("Invalid version bytes in keypair.")) + .and_then(|version| { + // 2. 
key + parts + .next() + .ok_or_else(|| Error::bad_database("Invalid keypair format in database.")) + .map(|key| (version, key)) + }) + .and_then(|(version, key)| { + Ed25519KeyPair::from_der(key, version) + .map_err(|_| Error::bad_database("Private or public keys are invalid.")) + }) + } - fn add_signing_key( - &self, - origin: &ServerName, - new_keys: ServerSigningKeys, - ) -> Result> { - // Not atomic, but this is not critical - let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; + fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") } - let mut keys = signingkeys - .and_then(|keys| serde_json::from_slice(&keys).ok()) - .unwrap_or_else(|| { - // Just insert "now", it doesn't matter - ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) - }); + fn add_signing_key( + &self, origin: &ServerName, new_keys: ServerSigningKeys, + ) -> Result> { + // Not atomic, but this is not critical + let signingkeys = self.server_signingkeys.get(origin.as_bytes())?; - let ServerSigningKeys { - verify_keys, - old_verify_keys, - .. - } = new_keys; + let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| { + // Just insert "now", it doesn't matter + ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now()) + }); - keys.verify_keys.extend(verify_keys); - keys.old_verify_keys.extend(old_verify_keys); + let ServerSigningKeys { + verify_keys, + old_verify_keys, + .. 
+ } = new_keys; - self.server_signingkeys.insert( - origin.as_bytes(), - &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), - )?; + keys.verify_keys.extend(verify_keys); + keys.old_verify_keys.extend(old_verify_keys); - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); + self.server_signingkeys.insert( + origin.as_bytes(), + &serde_json::to_vec(&keys).expect("serversigningkeys can be serialized"), + )?; - Ok(tree) - } + let mut tree = keys.verify_keys; + tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. - fn signing_keys_for( - &self, - origin: &ServerName, - ) -> Result> { - let signingkeys = self - .server_signingkeys - .get(origin.as_bytes())? - .and_then(|bytes| serde_json::from_slice(&bytes).ok()) - .map(|keys: ServerSigningKeys| { - let mut tree = keys.verify_keys; - tree.extend( - keys.old_verify_keys - .into_iter() - .map(|old| (old.0, VerifyKey::new(old.1.key))), - ); - tree - }) - .unwrap_or_else(BTreeMap::new); + Ok(tree) + } - Ok(signingkeys) - } + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. + fn signing_keys_for(&self, origin: &ServerName) -> Result> { + let signingkeys = self + .server_signingkeys + .get(origin.as_bytes())? 
+ .and_then(|bytes| serde_json::from_slice(&bytes).ok()) + .map(|keys: ServerSigningKeys| { + let mut tree = keys.verify_keys; + tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key)))); + tree + }) + .unwrap_or_else(BTreeMap::new); - fn database_version(&self) -> Result { - self.global.get(b"version")?.map_or(Ok(0), |version| { - utils::u64_from_bytes(&version) - .map_err(|_| Error::bad_database("Database version id is invalid.")) - }) - } + Ok(signingkeys) + } - fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.global.insert(b"version", &new_version.to_be_bytes())?; - Ok(()) - } + fn database_version(&self) -> Result { + self.global.get(b"version")?.map_or(Ok(0), |version| { + utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid.")) + }) + } + + fn bump_database_version(&self, new_version: u64) -> Result<()> { + self.global.insert(b"version", &new_version.to_be_bytes())?; + Ok(()) + } } diff --git a/src/database/key_value/key_backups.rs b/src/database/key_value/key_backups.rs index a8018938..5de65eaa 100644 --- a/src/database/key_value/key_backups.rs +++ b/src/database/key_value/key_backups.rs @@ -1,364 +1,292 @@ use std::collections::BTreeMap; use ruma::{ - api::client::{ - backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - error::ErrorKind, - }, - serde::Raw, - OwnedRoomId, RoomId, UserId, + api::client::{ + backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + error::ErrorKind, + }, + serde::Raw, + OwnedRoomId, RoomId, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::key_backups::Data for KeyValueDatabase { - fn create_backup( - &self, - user_id: &UserId, - backup_metadata: &Raw, - ) -> Result { - let version = services().globals.next_count()?.to_string(); + fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + let version = 
services().globals.next_count()?.to_string(); - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm.insert( - &key, - &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), - )?; - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - Ok(version) - } + self.backupid_algorithm.insert( + &key, + &serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"), + )?; + self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version) + } - fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm.remove(&key)?; - self.backupid_etag.remove(&key)?; + self.backupid_algorithm.remove(&key)?; + self.backupid_etag.remove(&key)?; - key.push(0xff); + key.push(0xFF); - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } - Ok(()) - } + Ok(()) + } - fn update_backup( - &self, - user_id: &UserId, - version: &str, - backup_metadata: &Raw, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - if 
self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Tried to update nonexistent backup.", - )); - } + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); + } - self.backupid_algorithm - .insert(&key, backup_metadata.json().get().as_bytes())?; - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; - Ok(version.to_owned()) - } + self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?; + self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; + Ok(version.to_owned()) + } - fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, _)| { - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) - }) - .transpose() - } + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, _)| { + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid.")) + }) + .transpose() + } - fn get_latest_backup( - &self, - user_id: &UserId, - ) -> Result)>> { - let mut prefix = 
user_id.as_bytes().to_vec(); - prefix.push(0xff); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - self.backupid_algorithm - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|(key, value)| { - let version = utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; + self.backupid_algorithm + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|(key, value)| { + let version = utils::string_from_bytes( + key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"), + ) + .map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?; - Ok(( - version, - serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("Algorithm in backupid_algorithm is invalid.") - })?, - )) - }) - .transpose() - } + Ok(( + version, + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?, + )) + }) + .transpose() + } - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - self.backupid_algorithm - .get(&key)? 
- .map_or(Ok(None), |bytes| { - serde_json::from_slice(&bytes) - .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) - }) - } + self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| { + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid.")) + }) + } - fn add_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - key_data: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - if self.backupid_algorithm.get(&key)?.is_none() { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Tried to update nonexistent backup.", - )); - } + if self.backupid_algorithm.get(&key)?.is_none() { + return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup.")); + } - self.backupid_etag - .insert(&key, &services().globals.next_count()?.to_be_bytes())?; + self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?; - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(session_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); - self.backupkeyid_backup - .insert(&key, key_data.json().get().as_bytes())?; + self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?; - Ok(()) - } + Ok(()) + } - fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(version.as_bytes()); + fn count_keys(&self, user_id: &UserId, version: &str) -> Result { + 
let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); - Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) - } + Ok(self.backupkeyid_backup.scan_prefix(prefix).count()) + } - fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); + fn get_etag(&self, user_id: &UserId, version: &str) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); - Ok(utils::u64_from_bytes( - &self - .backupid_etag - .get(&key)? - .ok_or_else(|| Error::bad_database("Backup has no etag."))?, - ) - .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? - .to_string()) - } + Ok(utils::u64_from_bytes( + &self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?, + ) + .map_err(|_| Error::bad_database("etag in backupid_etag invalid."))? + .to_string()) + } - fn get_all( - &self, - user_id: &UserId, - version: &str, - ) -> Result> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xff); + fn get_all(&self, user_id: &UserId, version: &str) -> Result> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xFF); - let mut rooms = BTreeMap::::new(); + let mut rooms = BTreeMap::::new(); - for result in self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xff); + for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); - let session_id = - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) 
- .map_err(|_| { - Error::bad_database("backupkeyid_backup session_id is invalid.") - })?; + let session_id = utils::string_from_bytes( + parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - let room_id = RoomId::parse( - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) - .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, - ) - .map_err(|_| { - Error::bad_database("backupkeyid_backup room_id is invalid room id.") - })?; + let room_id = RoomId::parse( + utils::string_from_bytes( + parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?; - let key_data = serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") - })?; + let key_data = serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - Ok::<_, Error>((room_id, session_id, key_data)) - }) - { - let (room_id, session_id, key_data) = result?; - rooms - .entry(room_id) - .or_insert_with(|| RoomKeyBackup { - sessions: BTreeMap::new(), - }) - .sessions - .insert(session_id, key_data); - } + Ok::<_, Error>((room_id, session_id, key_data)) + }) { + let (room_id, session_id, key_data) = result?; + rooms + .entry(room_id) + .or_insert_with(|| RoomKeyBackup { + sessions: BTreeMap::new(), + }) + .sessions + .insert(session_id, key_data); + } - Ok(rooms) - } + Ok(rooms) + } - fn get_room( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - ) -> Result>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - 
prefix.extend_from_slice(version.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, + ) -> Result>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(version.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); - Ok(self - .backupkeyid_backup - .scan_prefix(prefix) - .map(|(key, value)| { - let mut parts = key.rsplit(|&b| b == 0xff); + Ok(self + .backupkeyid_backup + .scan_prefix(prefix) + .map(|(key, value)| { + let mut parts = key.rsplit(|&b| b == 0xFF); - let session_id = - utils::string_from_bytes(parts.next().ok_or_else(|| { - Error::bad_database("backupkeyid_backup key is invalid.") - })?) - .map_err(|_| { - Error::bad_database("backupkeyid_backup session_id is invalid.") - })?; + let session_id = utils::string_from_bytes( + parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?, + ) + .map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?; - let key_data = serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") - })?; + let key_data = serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?; - Ok::<_, Error>((session_id, key_data)) - }) - .filter_map(std::result::Result::ok) - .collect()) - } + Ok::<_, Error>((session_id, key_data)) + }) + .filter_map(std::result::Result::ok) + .collect()) + } - fn get_session( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(session_id.as_bytes()); + fn 
get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); - self.backupkeyid_backup - .get(&key)? - .map(|value| { - serde_json::from_slice(&value).map_err(|_| { - Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.") - }) - }) - .transpose() - } + self.backupkeyid_backup + .get(&key)? + .map(|value| { + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")) + }) + .transpose() + } - fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); - key.push(0xff); + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } - Ok(()) - } + Ok(()) + } - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); + fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + 
key.push(0xFF); - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } - Ok(()) - } + Ok(()) + } - fn delete_room_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(version.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(session_id.as_bytes()); + fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(version.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(session_id.as_bytes()); - for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { - self.backupkeyid_backup.remove(&outdated_key)?; - } + for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) { + self.backupkeyid_backup.remove(&outdated_key)?; + } - Ok(()) - } + Ok(()) + } } diff --git a/src/database/key_value/media.rs b/src/database/key_value/media.rs index 1d879e72..af16d1f6 100644 --- a/src/database/key_value/media.rs +++ b/src/database/key_value/media.rs @@ -2,245 +2,182 @@ use ruma::api::client::error::ErrorKind; use tracing::debug; use crate::{ - database::KeyValueDatabase, - service::{self, media::UrlPreviewData}, - utils, Error, Result, + database::KeyValueDatabase, + service::{self, media::UrlPreviewData}, + utils, Error, Result, }; impl service::media::Data for KeyValueDatabase { - fn create_file_metadata( - &self, - mxc: String, - width: u32, - height: u32, - content_disposition: Option<&str>, - content_type: Option<&str>, - ) -> Result> { - let mut key = 
mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&width.to_be_bytes()); - key.extend_from_slice(&height.to_be_bytes()); - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); + fn create_file_metadata( + &self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, + ) -> Result> { + let mut key = mxc.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&width.to_be_bytes()); + key.extend_from_slice(&height.to_be_bytes()); + key.push(0xFF); + key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); - self.mediaid_file.insert(&key, &[])?; + self.mediaid_file.insert(&key, &[])?; - Ok(key) - } + Ok(key) + } - fn delete_file_mxc(&self, mxc: String) -> Result<()> { - debug!("MXC URI: {:?}", mxc); + fn delete_file_mxc(&self, mxc: String) -> Result<()> { + debug!("MXC URI: {:?}", mxc); - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); - debug!("MXC db prefix: {:?}", prefix); + debug!("MXC db prefix: {:?}", prefix); - for (key, _) in self.mediaid_file.scan_prefix(prefix) { - debug!("Deleting key: {:?}", key); - self.mediaid_file.remove(&key)?; - } + for (key, _) in self.mediaid_file.scan_prefix(prefix) { + debug!("Deleting key: {:?}", key); + self.mediaid_file.remove(&key)?; + } - Ok(()) - } + Ok(()) + } - /// Searches for all files with the given MXC - fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { - debug!("MXC URI: {:?}", mxc); + /// Searches for all files with the given MXC + fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>> { + debug!("MXC URI: 
{:?}", mxc); - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); - let mut keys: Vec> = vec![]; + let mut keys: Vec> = vec![]; - for (key, _) in self.mediaid_file.scan_prefix(prefix) { - keys.push(key); - } + for (key, _) in self.mediaid_file.scan_prefix(prefix) { + keys.push(key); + } - if keys.is_empty() { - return Err(Error::bad_database( - "Failed to find any keys in database with the provided MXC.", - )); - } + if keys.is_empty() { + return Err(Error::bad_database( + "Failed to find any keys in database with the provided MXC.", + )); + } - debug!("Got the following keys: {:?}", keys); + debug!("Got the following keys: {:?}", keys); - Ok(keys) - } + Ok(keys) + } - fn search_file_metadata( - &self, - mxc: String, - width: u32, - height: u32, - ) -> Result<(Option, Option, Vec)> { - let mut prefix = mxc.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(&width.to_be_bytes()); - prefix.extend_from_slice(&height.to_be_bytes()); - prefix.push(0xff); + fn search_file_metadata( + &self, mxc: String, width: u32, height: u32, + ) -> Result<(Option, Option, Vec)> { + let mut prefix = mxc.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(&width.to_be_bytes()); + prefix.extend_from_slice(&height.to_be_bytes()); + prefix.push(0xFF); - let (key, _) = self - .mediaid_file - .scan_prefix(prefix) - .next() - .ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; + let (key, _) = self + .mediaid_file + .scan_prefix(prefix) + .next() + .ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?; - let mut parts = key.rsplit(|&b| b == 0xff); + let mut parts = key.rsplit(|&b| b == 0xFF); - let content_type = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Content type in mediaid_file is invalid unicode.") - }) - }) - .transpose()?; + let content_type = parts + .next() + .map(|bytes| { + 
utils::string_from_bytes(bytes) + .map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode.")) + }) + .transpose()?; - let content_disposition_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; + let content_disposition_bytes = + parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?; - let content_disposition = if content_disposition_bytes.is_empty() { - None - } else { - Some( - utils::string_from_bytes(content_disposition_bytes).map_err(|_| { - Error::bad_database("Content Disposition in mediaid_file is invalid unicode.") - })?, - ) - }; - Ok((content_disposition, content_type, key)) - } + let content_disposition = if content_disposition_bytes.is_empty() { + None + } else { + Some( + utils::string_from_bytes(content_disposition_bytes) + .map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?, + ) + }; + Ok((content_disposition, content_type, key)) + } - /// Gets all the media keys in our database (this includes all the metadata associated with it such as width, height, content-type, etc) - fn get_all_media_keys(&self) -> Result>> { - let mut keys: Vec> = vec![]; + /// Gets all the media keys in our database (this includes all the metadata + /// associated with it such as width, height, content-type, etc) + fn get_all_media_keys(&self) -> Result>> { + let mut keys: Vec> = vec![]; - for (key, _) in self.mediaid_file.iter() { - keys.push(key); - } + for (key, _) in self.mediaid_file.iter() { + keys.push(key); + } - Ok(keys) - } + Ok(keys) + } - fn remove_url_preview(&self, url: &str) -> Result<()> { - self.url_previews.remove(url.as_bytes()) - } + fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) } - fn set_url_preview( - &self, - url: &str, - data: &UrlPreviewData, - timestamp: std::time::Duration, - ) -> Result<()> { - let mut value = Vec::::new(); - 
value.extend_from_slice(×tamp.as_secs().to_be_bytes()); - value.push(0xff); - value.extend_from_slice( - data.title - .as_ref() - .map(std::string::String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xff); - value.extend_from_slice( - data.description - .as_ref() - .map(std::string::String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xff); - value.extend_from_slice( - data.image - .as_ref() - .map(std::string::String::as_bytes) - .unwrap_or_default(), - ); - value.push(0xff); - value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); - value.push(0xff); - value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes()); - value.push(0xff); - value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); + fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> { + let mut value = Vec::::new(); + value.extend_from_slice(×tamp.as_secs().to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default()); + value.push(0xFF); + value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default()); + value.push(0xFF); + value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default()); + value.push(0xFF); + value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes()); + value.push(0xFF); + value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes()); - self.url_previews.insert(url.as_bytes(), &value) - } + self.url_previews.insert(url.as_bytes(), &value) + } - fn get_url_preview(&self, url: &str) -> Option { - let values = self.url_previews.get(url.as_bytes()).ok()??; + fn get_url_preview(&self, url: &str) -> Option { + let values = self.url_previews.get(url.as_bytes()).ok()??; - let mut values = values.split(|&b| b == 0xff); + 
let mut values = values.split(|&b| b == 0xFF); - let _ts = match values - .next() - .map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) - { - Some(0) => None, - x => x, - }; - let title = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let description = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let image = match values - .next() - .and_then(|b| String::from_utf8(b.to_vec()).ok()) - { - Some(s) if s.is_empty() => None, - x => x, - }; - let image_size = match values - .next() - .map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) - { - Some(0) => None, - x => x, - }; - let image_width = match values - .next() - .map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) - { - Some(0) => None, - x => x, - }; - let image_height = match values - .next() - .map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) - { - Some(0) => None, - x => x, - }; + let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) { + Some(0) => None, + x => x, + }; + let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + Some(s) if s.is_empty() => None, + x => x, + }; + let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + Some(s) if s.is_empty() => None, + x => x, + }; + let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) { + Some(s) if s.is_empty() => None, + x => x, + }; + let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) { + Some(0) => None, + x => x, + }; + let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) { + Some(0) => None, + x => x, + }; + let image_height = match values.next().map(|b| 
u32::from_be_bytes(b.try_into().expect("valid BE array"))) { + Some(0) => None, + x => x, + }; - Some(UrlPreviewData { - title, - description, - image, - image_size, - image_width, - image_height, - }) - } + Some(UrlPreviewData { + title, + description, + image, + image_size, + image_width, + image_height, + }) + } } diff --git a/src/database/key_value/pusher.rs b/src/database/key_value/pusher.rs index 50a6faca..f33eca34 100644 --- a/src/database/key_value/pusher.rs +++ b/src/database/key_value/pusher.rs @@ -1,79 +1,63 @@ use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, + api::client::push::{set_pusher, Pusher}, + UserId, }; use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::pusher::Data for KeyValueDatabase { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { - match &pusher { - set_pusher::v3::PusherAction::Post(data) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); - self.senderkey_pusher.insert( - &key, - &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"), - )?; - Ok(()) - } - set_pusher::v3::PusherAction::Delete(ids) => { - let mut key = sender.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(ids.pushkey.as_bytes()); - self.senderkey_pusher - .remove(&key) - .map(|_| ()) - .map_err(Into::into) - } - } - } + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + match &pusher { + set_pusher::v3::PusherAction::Post(data) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(data.pusher.ids.pushkey.as_bytes()); + self.senderkey_pusher + .insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?; + Ok(()) + }, + set_pusher::v3::PusherAction::Delete(ids) => { + let mut key = sender.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(ids.pushkey.as_bytes()); + 
self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into) + }, + } + } - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - let mut senderkey = sender.as_bytes().to_vec(); - senderkey.push(0xff); - senderkey.extend_from_slice(pushkey.as_bytes()); + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + let mut senderkey = sender.as_bytes().to_vec(); + senderkey.push(0xFF); + senderkey.extend_from_slice(pushkey.as_bytes()); - self.senderkey_pusher - .get(&senderkey)? - .map(|push| { - serde_json::from_slice(&push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .transpose() - } + self.senderkey_pusher + .get(&senderkey)? + .map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) + .transpose() + } - fn get_pushers(&self, sender: &UserId) -> Result> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); + fn get_pushers(&self, sender: &UserId) -> Result> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xFF); - self.senderkey_pusher - .scan_prefix(prefix) - .map(|(_, push)| { - serde_json::from_slice(&push) - .map_err(|_| Error::bad_database("Invalid Pusher in db.")) - }) - .collect() - } + self.senderkey_pusher + .scan_prefix(prefix) + .map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db."))) + .collect() + } - fn get_pushkeys<'a>( - &'a self, - sender: &UserId, - ) -> Box> + 'a> { - let mut prefix = sender.as_bytes().to_vec(); - prefix.push(0xff); + fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a> { + let mut prefix = sender.as_bytes().to_vec(); + prefix.push(0xFF); - Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { - let mut parts = k.splitn(2, |&b| b == 0xff); - let _senderkey = parts.next(); - let push_key = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; - let push_key_string = utils::string_from_bytes(push_key) - 
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; + Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| { + let mut parts = k.splitn(2, |&b| b == 0xFF); + let _senderkey = parts.next(); + let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?; + let push_key_string = utils::string_from_bytes(push_key) + .map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?; - Ok(push_key_string) - })) - } + Ok(push_key_string) + })) + } } diff --git a/src/database/key_value/rooms/alias.rs b/src/database/key_value/rooms/alias.rs index 1b16d9a2..e035784a 100644 --- a/src/database/key_value/rooms/alias.rs +++ b/src/database/key_value/rooms/alias.rs @@ -3,82 +3,68 @@ use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, RoomAli use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::alias::Data for KeyValueDatabase { - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { - self.alias_roomid - .insert(alias.alias().as_bytes(), room_id.as_bytes())?; - let mut aliasid = room_id.as_bytes().to_vec(); - aliasid.push(0xff); - aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; - Ok(()) - } + fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { + self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?; + let mut aliasid = room_id.as_bytes().to_vec(); + aliasid.push(0xFF); + aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + self.aliasid_alias.insert(&aliasid, alias.as_bytes())?; + Ok(()) + } - fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? 
{ - let mut prefix = room_id; - prefix.push(0xff); + fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { + if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? { + let mut prefix = room_id; + prefix.push(0xFF); - for (key, _) in self.aliasid_alias.scan_prefix(prefix) { - self.aliasid_alias.remove(&key)?; - } - self.alias_roomid.remove(alias.alias().as_bytes())?; - } else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Alias does not exist.", - )); - } - Ok(()) - } + for (key, _) in self.aliasid_alias.scan_prefix(prefix) { + self.aliasid_alias.remove(&key)?; + } + self.alias_roomid.remove(alias.alias().as_bytes())?; + } else { + return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist.")); + } + Ok(()) + } - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.alias_roomid - .get(alias.alias().as_bytes())? - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in alias_roomid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) - }) - .transpose() - } + fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + self.alias_roomid + .get(alias.alias().as_bytes())? 
+ .map(|bytes| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid.")) + }) + .transpose() + } - fn local_aliases_for_room<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { - utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? - .try_into() - .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) - })) - } + Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| { + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid alias in aliasid_alias.")) + })) + } - fn all_local_aliases<'a>( - &'a self, - ) -> Box> + 'a> { - Box::new( - self.alias_roomid - .iter() - .map(|(room_alias_bytes, room_id_bytes)| { - let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) - .map_err(|_| { - Error::bad_database("Invalid alias bytes in aliasid_alias.") - })?; + fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { + Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| { + let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes) + .map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?; - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|_| { - Error::bad_database("Invalid room_id bytes in aliasid_alias.") - })? 
- .try_into() - .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; + let room_id = utils::string_from_bytes(&room_id_bytes) + .map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))? + .try_into() + .map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?; - Ok((room_id, room_alias_localpart)) - }), - ) - } + Ok((room_id, room_alias_localpart)) + })) + } } diff --git a/src/database/key_value/rooms/auth_chain.rs b/src/database/key_value/rooms/auth_chain.rs index 60057ac1..89e8502f 100644 --- a/src/database/key_value/rooms/auth_chain.rs +++ b/src/database/key_value/rooms/auth_chain.rs @@ -3,59 +3,47 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{database::KeyValueDatabase, service, utils, Result}; impl service::rooms::auth_chain::Data for KeyValueDatabase { - fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { - // Check RAM cache - if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { - return Ok(Some(Arc::clone(result))); - } + fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { + // Check RAM cache + if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) { + return Ok(Some(Arc::clone(result))); + } - // We only save auth chains for single events in the db - if key.len() == 1 { - // Check DB cache - let chain = self - .shorteventid_authchain - .get(&key[0].to_be_bytes())? 
- .map(|chain| { - chain - .chunks_exact(size_of::()) - .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) - .collect() - }); + // We only save auth chains for single events in the db + if key.len() == 1 { + // Check DB cache + let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| { + chain + .chunks_exact(size_of::()) + .map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct")) + .collect() + }); - if let Some(chain) = chain { - let chain = Arc::new(chain); + if let Some(chain) = chain { + let chain = Arc::new(chain); - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(vec![key[0]], Arc::clone(&chain)); + // Cache in RAM + self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain)); - return Ok(Some(chain)); - } - } + return Ok(Some(chain)); + } + } - Ok(None) - } + Ok(None) + } - fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { - // Only persist single events in db - if key.len() == 1 { - self.shorteventid_authchain.insert( - &key[0].to_be_bytes(), - &auth_chain - .iter() - .flat_map(|s| s.to_be_bytes().to_vec()) - .collect::>(), - )?; - } + fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + // Only persist single events in db + if key.len() == 1 { + self.shorteventid_authchain.insert( + &key[0].to_be_bytes(), + &auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::>(), + )?; + } - // Cache in RAM - self.auth_chain_cache - .lock() - .unwrap() - .insert(key, auth_chain); + // Cache in RAM + self.auth_chain_cache.lock().unwrap().insert(key, auth_chain); - Ok(()) - } + Ok(()) + } } diff --git a/src/database/key_value/rooms/directory.rs b/src/database/key_value/rooms/directory.rs index e05dee82..20ccfb55 100644 --- a/src/database/key_value/rooms/directory.rs +++ b/src/database/key_value/rooms/directory.rs @@ -3,26 +3,21 @@ use ruma::{OwnedRoomId, RoomId}; use crate::{database::KeyValueDatabase, service, 
utils, Error, Result}; impl service::rooms::directory::Data for KeyValueDatabase { - fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.insert(room_id.as_bytes(), &[]) - } + fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) } - fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.publicroomids.remove(room_id.as_bytes()) - } + fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) } - fn is_public_room(&self, room_id: &RoomId) -> Result { - Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) - } + fn is_public_room(&self, room_id: &RoomId) -> Result { + Ok(self.publicroomids.get(room_id.as_bytes())?.is_some()) + } - fn public_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.publicroomids.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) - })) - } + fn public_rooms<'a>(&'a self) -> Box> + 'a> { + Box::new(self.publicroomids.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid.")) + })) + } } diff --git a/src/database/key_value/rooms/edus/presence.rs b/src/database/key_value/rooms/edus/presence.rs index 85f9b848..56feeb03 100644 --- a/src/database/key_value/rooms/edus/presence.rs +++ b/src/database/key_value/rooms/edus/presence.rs @@ -1,178 +1,155 @@ use std::time::Duration; -use ruma::{ - events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, -}; +use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId}; use tracing::error; use crate::{ - 
database::KeyValueDatabase, - service::{self, rooms::edus::presence::Presence}, - services, - utils::{self, user_id_from_bytes}, - Error, Result, + database::KeyValueDatabase, + service::{self, rooms::edus::presence::Presence}, + services, + utils::{self, user_id_from_bytes}, + Error, Result, }; impl service::rooms::edus::presence::Data for KeyValueDatabase { - fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let key = presence_key(room_id, user_id); + fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let key = presence_key(room_id, user_id); - self.roomuserid_presence - .get(&key)? - .map(|presence_bytes| -> Result { - Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id) - }) - .transpose() - } + self.roomuserid_presence + .get(&key)? + .map(|presence_bytes| -> Result { + Presence::from_json_bytes(&presence_bytes)?.to_presence_event(user_id) + }) + .transpose() + } - fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> { - let now = utils::millis_since_unix_epoch(); - let mut state_changed = false; + fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> { + let now = utils::millis_since_unix_epoch(); + let mut state_changed = false; - for room_id in services().rooms.state_cache.rooms_joined(user_id) { - let key = presence_key(&room_id?, user_id); + for room_id in services().rooms.state_cache.rooms_joined(user_id) { + let key = presence_key(&room_id?, user_id); - let presence_bytes = self.roomuserid_presence.get(&key)?; + let presence_bytes = self.roomuserid_presence.get(&key)?; - if let Some(presence_bytes) = presence_bytes { - let presence = Presence::from_json_bytes(&presence_bytes)?; - if presence.state != new_state { - state_changed = true; - break; - } - } - } + if let Some(presence_bytes) = presence_bytes { + let presence = Presence::from_json_bytes(&presence_bytes)?; + if presence.state != new_state { + state_changed = true; + break; + 
} + } + } - let count = if state_changed { - services().globals.next_count()? - } else { - services().globals.current_count()? - }; + let count = if state_changed { + services().globals.next_count()? + } else { + services().globals.current_count()? + }; - for room_id in services().rooms.state_cache.rooms_joined(user_id) { - let key = presence_key(&room_id?, user_id); + for room_id in services().rooms.state_cache.rooms_joined(user_id) { + let key = presence_key(&room_id?, user_id); - let presence_bytes = self.roomuserid_presence.get(&key)?; + let presence_bytes = self.roomuserid_presence.get(&key)?; - let new_presence = match presence_bytes { - Some(presence_bytes) => { - let mut presence = Presence::from_json_bytes(&presence_bytes)?; - presence.state = new_state.clone(); - presence.currently_active = presence.state == PresenceState::Online; - presence.last_active_ts = now; - presence.last_count = count; + let new_presence = match presence_bytes { + Some(presence_bytes) => { + let mut presence = Presence::from_json_bytes(&presence_bytes)?; + presence.state = new_state.clone(); + presence.currently_active = presence.state == PresenceState::Online; + presence.last_active_ts = now; + presence.last_count = count; - presence - } - None => Presence::new( - new_state.clone(), - new_state == PresenceState::Online, - now, - count, - None, - ), - }; + presence + }, + None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None), + }; - self.roomuserid_presence - .insert(&key, &new_presence.to_json_bytes()?)?; - } + self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?; + } - let timeout = match new_state { - PresenceState::Online => services().globals.config.presence_idle_timeout_s, - _ => services().globals.config.presence_offline_timeout_s, - }; + let timeout = match new_state { + PresenceState::Online => services().globals.config.presence_idle_timeout_s, + _ => services().globals.config.presence_offline_timeout_s, + }; - 
self.presence_timer_sender - .send((user_id.to_owned(), Duration::from_secs(timeout))) - .map_err(|e| { - error!("Failed to add presence timer: {}", e); - Error::bad_database("Failed to add presence timer") - }) - } + self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { + error!("Failed to add presence timer: {}", e); + Error::bad_database("Failed to add presence timer") + }) + } - fn set_presence( - &self, - room_id: &RoomId, - user_id: &UserId, - presence_state: PresenceState, - currently_active: Option, - last_active_ago: Option, - status_msg: Option, - ) -> Result<()> { - let now = utils::millis_since_unix_epoch(); - let last_active_ts = match last_active_ago { - Some(last_active_ago) => now.saturating_sub(last_active_ago.into()), - None => now, - }; + fn set_presence( + &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option, + last_active_ago: Option, status_msg: Option, + ) -> Result<()> { + let now = utils::millis_since_unix_epoch(); + let last_active_ts = match last_active_ago { + Some(last_active_ago) => now.saturating_sub(last_active_ago.into()), + None => now, + }; - let key = presence_key(room_id, user_id); + let key = presence_key(room_id, user_id); - let presence = Presence::new( - presence_state, - currently_active.unwrap_or(false), - last_active_ts, - services().globals.next_count()?, - status_msg, - ); + let presence = Presence::new( + presence_state, + currently_active.unwrap_or(false), + last_active_ts, + services().globals.next_count()?, + status_msg, + ); - let timeout = match presence.state { - PresenceState::Online => services().globals.config.presence_idle_timeout_s, - _ => services().globals.config.presence_offline_timeout_s, - }; + let timeout = match presence.state { + PresenceState::Online => services().globals.config.presence_idle_timeout_s, + _ => services().globals.config.presence_offline_timeout_s, + }; - self.presence_timer_sender - 
.send((user_id.to_owned(), Duration::from_secs(timeout))) - .map_err(|e| { - error!("Failed to add presence timer: {}", e); - Error::bad_database("Failed to add presence timer") - })?; + self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| { + error!("Failed to add presence timer: {}", e); + Error::bad_database("Failed to add presence timer") + })?; - self.roomuserid_presence - .insert(&key, &presence.to_json_bytes()?)?; + self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?; - Ok(()) - } + Ok(()) + } - fn remove_presence(&self, user_id: &UserId) -> Result<()> { - for room_id in services().rooms.state_cache.rooms_joined(user_id) { - let key = presence_key(&room_id?, user_id); + fn remove_presence(&self, user_id: &UserId) -> Result<()> { + for room_id in services().rooms.state_cache.rooms_joined(user_id) { + let key = presence_key(&room_id?, user_id); - self.roomuserid_presence.remove(&key)?; - } + self.roomuserid_presence.remove(&key)?; + } - Ok(()) - } + Ok(()) + } - fn presence_since<'a>( - &'a self, - room_id: &RoomId, - since: u64, - ) -> Box + 'a> { - let prefix = [room_id.as_bytes(), &[0xff]].concat(); + fn presence_since<'a>( + &'a self, room_id: &RoomId, since: u64, + ) -> Box + 'a> { + let prefix = [room_id.as_bytes(), &[0xFF]].concat(); - Box::new( - self.roomuserid_presence - .scan_prefix(prefix) - .flat_map( - |(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> { - let user_id = user_id_from_bytes( - key.rsplit(|byte| *byte == 0xff).next().ok_or_else(|| { - Error::bad_database("No UserID bytes in presence key") - })?, - )?; + Box::new( + self.roomuserid_presence + .scan_prefix(prefix) + .flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> { + let user_id = user_id_from_bytes( + key.rsplit(|byte| *byte == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?, + )?; - let presence = 
Presence::from_json_bytes(&presence_bytes)?; - let presence_event = presence.to_presence_event(&user_id)?; + let presence = Presence::from_json_bytes(&presence_bytes)?; + let presence_event = presence.to_presence_event(&user_id)?; - Ok((user_id, presence.last_count, presence_event)) - }, - ) - .filter(move |(_, count, _)| *count > since), - ) - } + Ok((user_id, presence.last_count, presence_event)) + }) + .filter(move |(_, count, _)| *count > since), + ) + } } #[inline] fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec { - [room_id.as_bytes(), &[0xff], user_id.as_bytes()].concat() + [room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat() } diff --git a/src/database/key_value/rooms/edus/read_receipt.rs b/src/database/key_value/rooms/edus/read_receipt.rs index fa97ea34..ab191b39 100644 --- a/src/database/key_value/rooms/edus/read_receipt.rs +++ b/src/database/key_value/rooms/edus/read_receipt.rs @@ -1,150 +1,113 @@ use std::mem; -use ruma::{ - events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId, -}; +use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::edus::read_receipt::Data for KeyValueDatabase { - fn readreceipt_update( - &self, - user_id: &UserId, - room_id: &RoomId, - event: ReceiptEvent, - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - // Remove old entry - if let Some((old, _)) = self - .readreceiptid_readreceipt - 
.iter_from(&last_possible_key, true) - .take_while(|(key, _)| key.starts_with(&prefix)) - .find(|(key, _)| { - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element") - == user_id.as_bytes() - }) - { - // This is the old room_latest - self.readreceiptid_readreceipt.remove(&old)?; - } + // Remove old entry + if let Some((old, _)) = self + .readreceiptid_readreceipt + .iter_from(&last_possible_key, true) + .take_while(|(key, _)| key.starts_with(&prefix)) + .find(|(key, _)| { + key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes() + }) { + // This is the old room_latest + self.readreceiptid_readreceipt.remove(&old)?; + } - let mut room_latest_id = prefix; - room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - room_latest_id.push(0xff); - room_latest_id.extend_from_slice(user_id.as_bytes()); + let mut room_latest_id = prefix; + room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + room_latest_id.push(0xFF); + room_latest_id.extend_from_slice(user_id.as_bytes()); - self.readreceiptid_readreceipt.insert( - &room_latest_id, - &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), - )?; + self.readreceiptid_readreceipt.insert( + &room_latest_id, + &serde_json::to_vec(&event).expect("EduEvent::to_string always works"), + )?; - Ok(()) - } + Ok(()) + } - fn readreceipts_since<'a>( - &'a self, - room_id: &RoomId, - since: u64, - ) -> Box< - dyn Iterator< - Item = Result<( - OwnedUserId, - u64, - Raw, - )>, - > + 'a, - > { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - let prefix2 = prefix.clone(); + fn readreceipts_since<'a>( + &'a self, room_id: &RoomId, since: u64, + ) -> Box)>> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + let prefix2 = prefix.clone(); - let mut first_possible_edu = prefix.clone(); - first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // 
+1 so we don't send the event at since + let mut first_possible_edu = prefix.clone(); + first_possible_edu.extend_from_slice(&(since + 1).to_be_bytes()); // +1 so we don't send the event at since - Box::new( - self.readreceiptid_readreceipt - .iter_from(&first_possible_edu, false) - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(k, v)| { - let count = utils::u64_from_bytes( - &k[prefix.len()..prefix.len() + mem::size_of::()], - ) - .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; - let user_id = UserId::parse( - utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) - .map_err(|_| { - Error::bad_database("Invalid readreceiptid userid bytes in db.") - })?, - ) - .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; + Box::new( + self.readreceiptid_readreceipt + .iter_from(&first_possible_edu, false) + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(k, v)| { + let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::()]) + .map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?; + let user_id = UserId::parse( + utils::string_from_bytes(&k[prefix.len() + mem::size_of::() + 1..]) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?, + ) + .map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?; - let mut json = - serde_json::from_slice::(&v).map_err(|_| { - Error::bad_database( - "Read receipt in roomlatestid_roomlatest is invalid json.", - ) - })?; - json.remove("room_id"); + let mut json = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?; + json.remove("room_id"); - Ok(( - user_id, - count, - Raw::from_json( - serde_json::value::to_raw_value(&json) - .expect("json is valid raw value"), - ), - )) - }), - ) - } + Ok(( + user_id, + count, + Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid 
raw value")), + )) + }), + ) + } - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); + fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread - .insert(&key, &count.to_be_bytes())?; + self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?; - self.roomuserid_lastprivatereadupdate - .insert(&key, &services().globals.next_count()?.to_be_bytes()) - } + self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes()) + } - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); - self.roomuserid_privateread - .get(&key)? - .map_or(Ok(None), |v| { - Ok(Some(utils::u64_from_bytes(&v).map_err(|_| { - Error::bad_database("Invalid private read marker bytes") - })?)) - }) - } + self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| { + Ok(Some( + utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?, + )) + }) + } - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); + fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); - Ok(self - .roomuserid_lastprivatereadupdate - .get(&key)? 
- .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) - } + Ok(self + .roomuserid_lastprivatereadupdate + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) + }) + .transpose()? + .unwrap_or(0)) + } } diff --git a/src/database/key_value/rooms/edus/typing.rs b/src/database/key_value/rooms/edus/typing.rs index 28c06652..e1724aa7 100644 --- a/src/database/key_value/rooms/edus/typing.rs +++ b/src/database/key_value/rooms/edus/typing.rs @@ -5,123 +5,111 @@ use ruma::{OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::edus::typing::Data for KeyValueDatabase { - fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - let count = services().globals.next_count()?.to_be_bytes(); + let count = services().globals.next_count()?.to_be_bytes(); - let mut room_typing_id = prefix; - room_typing_id.extend_from_slice(&timeout.to_be_bytes()); - room_typing_id.push(0xff); - room_typing_id.extend_from_slice(&count); + let mut room_typing_id = prefix; + room_typing_id.extend_from_slice(&timeout.to_be_bytes()); + room_typing_id.push(0xFF); + room_typing_id.extend_from_slice(&count); - self.typingid_userid - .insert(&room_typing_id, user_id.as_bytes())?; + self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?; - self.roomid_lasttypingupdate - .insert(room_id.as_bytes(), &count)?; + self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?; - Ok(()) - } + Ok(()) + } - fn typing_remove(&self, user_id: 
&UserId, room_id: &RoomId) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - let user_id = user_id.to_string(); + let user_id = user_id.to_string(); - let mut found_outdated = false; + let mut found_outdated = false; - // Maybe there are multiple ones from calling roomtyping_add multiple times - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .filter(|(_, v)| &**v == user_id.as_bytes()) - { - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } + // Maybe there are multiple ones from calling roomtyping_add multiple times + for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) { + self.typingid_userid.remove(&outdated_edu.0)?; + found_outdated = true; + } - if found_outdated { - self.roomid_lasttypingupdate.insert( - room_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), - )?; - } + if found_outdated { + self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + } - Ok(()) - } + Ok(()) + } - fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - let current_timestamp = utils::millis_since_unix_epoch(); + let current_timestamp = utils::millis_since_unix_epoch(); - let mut found_outdated = false; + let mut found_outdated = false; - // Find all outdated edus before inserting a new one - for outdated_edu in self - .typingid_userid - .scan_prefix(prefix) - .map(|(key, _)| { - Ok::<_, Error>(( - key.clone(), - utils::u64_from_bytes( - &key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| { - Error::bad_database("RoomTyping has 
invalid timestamp or delimiters.") - })?[0..mem::size_of::()], - ) - .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, - )) - }) - .filter_map(std::result::Result::ok) - .take_while(|&(_, timestamp)| timestamp < current_timestamp) - { - // This is an outdated edu (time > timestamp) - self.typingid_userid.remove(&outdated_edu.0)?; - found_outdated = true; - } + // Find all outdated edus before inserting a new one + for outdated_edu in self + .typingid_userid + .scan_prefix(prefix) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes( + &key.splitn(2, |&b| b == 0xFF) + .nth(1) + .ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::()], + ) + .map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?, + )) + }) + .filter_map(std::result::Result::ok) + .take_while(|&(_, timestamp)| timestamp < current_timestamp) + { + // This is an outdated edu (time > timestamp) + self.typingid_userid.remove(&outdated_edu.0)?; + found_outdated = true; + } - if found_outdated { - self.roomid_lasttypingupdate.insert( - room_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), - )?; - } + if found_outdated { + self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + } - Ok(()) - } + Ok(()) + } - fn last_typing_update(&self, room_id: &RoomId) -> Result { - Ok(self - .roomid_lasttypingupdate - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) - } + fn last_typing_update(&self, room_id: &RoomId) -> Result { + Ok(self + .roomid_lasttypingupdate + .get(room_id.as_bytes())? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) + }) + .transpose()? 
+ .unwrap_or(0)) + } - fn typings_all(&self, room_id: &RoomId) -> Result> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn typings_all(&self, room_id: &RoomId) -> Result> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - let mut user_ids = HashSet::new(); + let mut user_ids = HashSet::new(); - for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { - let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| { - Error::bad_database("User ID in typingid_userid is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; + for (_, user_id) in self.typingid_userid.scan_prefix(prefix) { + let user_id = UserId::parse( + utils::string_from_bytes(&user_id) + .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?; - user_ids.insert(user_id); - } + user_ids.insert(user_id); + } - Ok(user_ids) - } + Ok(user_ids) + } } diff --git a/src/database/key_value/rooms/lazy_load.rs b/src/database/key_value/rooms/lazy_load.rs index a19d52cb..080eb4b8 100644 --- a/src/database/key_value/rooms/lazy_load.rs +++ b/src/database/key_value/rooms/lazy_load.rs @@ -3,63 +3,51 @@ use ruma::{DeviceId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, Result}; impl service::rooms::lazy_loading::Data for KeyValueDatabase { - fn lazy_load_was_sent_before( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ll_user: &UserId, - ) -> Result { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(ll_user.as_bytes()); - Ok(self.lazyloadedids.get(&key)?.is_some()) - } + fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: 
&UserId, + ) -> Result { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(ll_user.as_bytes()); + Ok(self.lazyloadedids.get(&key)?.is_some()) + } - fn lazy_load_confirm_delivery( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + fn lazy_load_confirm_delivery( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, + confirmed_user_ids: &mut dyn Iterator, + ) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); - for ll_id in confirmed_user_ids { - let mut key = prefix.clone(); - key.extend_from_slice(ll_id.as_bytes()); - self.lazyloadedids.insert(&key, &[])?; - } + for ll_id in confirmed_user_ids { + let mut key = prefix.clone(); + key.extend_from_slice(ll_id.as_bytes()); + self.lazyloadedids.insert(&key, &[])?; + } - Ok(()) - } + Ok(()) + } - fn lazy_load_reset( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.extend_from_slice(room_id.as_bytes()); - prefix.push(0xff); + fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + 
prefix.extend_from_slice(room_id.as_bytes()); + prefix.push(0xFF); - for (key, _) in self.lazyloadedids.scan_prefix(prefix) { - self.lazyloadedids.remove(&key)?; - } + for (key, _) in self.lazyloadedids.scan_prefix(prefix) { + self.lazyloadedids.remove(&key)?; + } - Ok(()) - } + Ok(()) + } } diff --git a/src/database/key_value/rooms/metadata.rs b/src/database/key_value/rooms/metadata.rs index a97878e1..78fba9d3 100644 --- a/src/database/key_value/rooms/metadata.rs +++ b/src/database/key_value/rooms/metadata.rs @@ -4,76 +4,68 @@ use tracing::error; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::metadata::Data for KeyValueDatabase { - fn exists(&self, room_id: &RoomId) -> Result { - let prefix = match services().rooms.short.get_shortroomid(room_id)? { - Some(b) => b.to_be_bytes().to_vec(), - None => return Ok(false), - }; + fn exists(&self, room_id: &RoomId) -> Result { + let prefix = match services().rooms.short.get_shortroomid(room_id)? { + Some(b) => b.to_be_bytes().to_vec(), + None => return Ok(false), + }; - // Look for PDUs in that room. - Ok(self - .pduid_pdu - .iter_from(&prefix, false) - .next() - .filter(|(k, _)| k.starts_with(&prefix)) - .is_some()) - } + // Look for PDUs in that room. 
+ Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some()) + } - fn iter_ids<'a>(&'a self) -> Box> + 'a> { - Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { - RoomId::parse( - utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Room ID in publicroomids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) - })) - } + fn iter_ids<'a>(&'a self) -> Box> + 'a> { + Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid.")) + })) + } - fn is_disabled(&self, room_id: &RoomId) -> Result { - Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) - } + fn is_disabled(&self, room_id: &RoomId) -> Result { + Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some()) + } - fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - if disabled { - self.disabledroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.disabledroomids.remove(room_id.as_bytes())?; - } + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + if disabled { + self.disabledroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.disabledroomids.remove(room_id.as_bytes())?; + } - Ok(()) - } + Ok(()) + } - fn is_banned(&self, room_id: &RoomId) -> Result { - Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) - } + fn is_banned(&self, room_id: &RoomId) -> Result { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) } - fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - if banned { - self.bannedroomids.insert(room_id.as_bytes(), &[])?; - } else { - self.bannedroomids.remove(room_id.as_bytes())?; - } + fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { + if banned { 
+ self.bannedroomids.insert(room_id.as_bytes(), &[])?; + } else { + self.bannedroomids.remove(room_id.as_bytes())?; + } - Ok(()) - } + Ok(()) + } - fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - Box::new(self.bannedroomids.iter().map( - |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { - let room_id = utils::string_from_bytes(&room_id_bytes) - .map_err(|e| { - error!("Invalid room_id bytes in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids.") - })? - .try_into() - .map_err(|e| { - error!("Invalid room_id in bannedroomids: {e}"); - Error::bad_database("Invalid room_id in bannedroomids") - })?; + fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { + Box::new(self.bannedroomids.iter().map( + |(room_id_bytes, _ /* non-banned rooms should not be in this table */)| { + let room_id = utils::string_from_bytes(&room_id_bytes) + .map_err(|e| { + error!("Invalid room_id bytes in bannedroomids: {e}"); + Error::bad_database("Invalid room_id in bannedroomids.") + })? + .try_into() + .map_err(|e| { + error!("Invalid room_id in bannedroomids: {e}"); + Error::bad_database("Invalid room_id in bannedroomids") + })?; - Ok(room_id) - }, - )) - } + Ok(room_id) + }, + )) + } } diff --git a/src/database/key_value/rooms/outlier.rs b/src/database/key_value/rooms/outlier.rs index 7985ba81..f45e78fc 100644 --- a/src/database/key_value/rooms/outlier.rs +++ b/src/database/key_value/rooms/outlier.rs @@ -3,26 +3,22 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result}; impl service::rooms::outlier::Data for KeyValueDatabase { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? 
- .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } - fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map_or(Ok(None), |pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - } + fn get_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| { + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) + }) + } - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.eventid_outlierpdu.insert( - event_id.as_bytes(), - &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), - ) - } + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + self.eventid_outlierpdu.insert( + event_id.as_bytes(), + &serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"), + ) + } } diff --git a/src/database/key_value/rooms/pdu_metadata.rs b/src/database/key_value/rooms/pdu_metadata.rs index 0641f9d8..44c621c7 100644 --- a/src/database/key_value/rooms/pdu_metadata.rs +++ b/src/database/key_value/rooms/pdu_metadata.rs @@ -3,85 +3,78 @@ use std::{mem, sync::Arc}; use ruma::{EventId, RoomId, UserId}; use crate::{ - database::KeyValueDatabase, - service::{self, rooms::timeline::PduCount}, - services, utils, Error, PduEvent, Result, + database::KeyValueDatabase, + service::{self, rooms::timeline::PduCount}, + services, utils, Error, PduEvent, Result, }; impl service::rooms::pdu_metadata::Data for KeyValueDatabase { - fn add_relation(&self, from: u64, to: u64) -> Result<()> { - let mut key = 
to.to_be_bytes().to_vec(); - key.extend_from_slice(&from.to_be_bytes()); - self.tofrom_relation.insert(&key, &[])?; - Ok(()) - } + fn add_relation(&self, from: u64, to: u64) -> Result<()> { + let mut key = to.to_be_bytes().to_vec(); + key.extend_from_slice(&from.to_be_bytes()); + self.tofrom_relation.insert(&key, &[])?; + Ok(()) + } - fn relations_until<'a>( - &'a self, - user_id: &'a UserId, - shortroomid: u64, - target: u64, - until: PduCount, - ) -> Result> + 'a>> { - let prefix = target.to_be_bytes().to_vec(); - let mut current = prefix.clone(); + fn relations_until<'a>( + &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, + ) -> Result> + 'a>> { + let prefix = target.to_be_bytes().to_vec(); + let mut current = prefix.clone(); - let count_raw = match until { - PduCount::Normal(x) => x - 1, - PduCount::Backfilled(x) => { - current.extend_from_slice(&0_u64.to_be_bytes()); - u64::MAX - x - 1 - } - }; - current.extend_from_slice(&count_raw.to_be_bytes()); + let count_raw = match until { + PduCount::Normal(x) => x - 1, + PduCount::Backfilled(x) => { + current.extend_from_slice(&0_u64.to_be_bytes()); + u64::MAX - x - 1 + }, + }; + current.extend_from_slice(&count_raw.to_be_bytes()); - Ok(Box::new( - self.tofrom_relation - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(tofrom, _data)| { - let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; + Ok(Box::new( + self.tofrom_relation.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( + move |(tofrom, _data)| { + let from = utils::u64_from_bytes(&tofrom[(mem::size_of::())..]) + .map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?; - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); + let mut pduid = shortroomid.to_be_bytes().to_vec(); + pduid.extend_from_slice(&from.to_be_bytes()); - 
let mut pdu = services() - .rooms - .timeline - .get_pdu_from_id(&pduid)? - .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((PduCount::Normal(from), pdu)) - }), - )) - } + let mut pdu = services() + .rooms + .timeline + .get_pdu_from_id(&pduid)? + .ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((PduCount::Normal(from), pdu)) + }, + ), + )) + } - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - for prev in event_ids { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(prev.as_bytes()); - self.referencedevents.insert(&key, &[])?; - } + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + for prev in event_ids { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(prev.as_bytes()); + self.referencedevents.insert(&key, &[])?; + } - Ok(()) - } + Ok(()) + } - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.extend_from_slice(event_id.as_bytes()); - Ok(self.referencedevents.get(&key)?.is_some()) - } + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.extend_from_slice(event_id.as_bytes()); + Ok(self.referencedevents.get(&key)?.is_some()) + } - fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.softfailedeventids.insert(event_id.as_bytes(), &[]) - } + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { + self.softfailedeventids.insert(event_id.as_bytes(), &[]) + } - fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.softfailedeventids - .get(event_id.as_bytes()) - .map(|o| o.is_some()) - } + fn is_event_soft_failed(&self, event_id: &EventId) -> Result { + 
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some()) + } } diff --git a/src/database/key_value/rooms/search.rs b/src/database/key_value/rooms/search.rs index 128b3019..05f6f749 100644 --- a/src/database/key_value/rooms/search.rs +++ b/src/database/key_value/rooms/search.rs @@ -5,61 +5,55 @@ use crate::{database::KeyValueDatabase, service, services, utils, Result}; type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; impl service::rooms::search::Data for KeyValueDatabase { - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - let mut batch = message_body - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) - .map(str::to_lowercase) - .map(|word| { - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(word.as_bytes()); - key.push(0xff); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here - (key, Vec::new()) - }); + fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + let mut batch = message_body + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .filter(|word| word.len() <= 50) + .map(str::to_lowercase) + .map(|word| { + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(word.as_bytes()); + key.push(0xFF); + key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + (key, Vec::new()) + }); - self.tokenids.insert_batch(&mut batch) - } + self.tokenids.insert_batch(&mut batch) + } - fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? 
- .expect("room exists") - .to_be_bytes() - .to_vec(); + fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> { + let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); - let words: Vec<_> = search_string - .split_terminator(|c: char| !c.is_alphanumeric()) - .filter(|s| !s.is_empty()) - .map(str::to_lowercase) - .collect(); + let words: Vec<_> = search_string + .split_terminator(|c: char| !c.is_alphanumeric()) + .filter(|s| !s.is_empty()) + .map(str::to_lowercase) + .collect(); - let iterators = words.clone().into_iter().map(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xff); - let prefix3 = prefix2.clone(); + let iterators = words.clone().into_iter().map(move |word| { + let mut prefix2 = prefix.clone(); + prefix2.extend_from_slice(word.as_bytes()); + prefix2.push(0xFF); + let prefix3 = prefix2.clone(); - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); + let mut last_possible_id = prefix2.clone(); + last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - self.tokenids - .iter_from(&last_possible_id, true) // Newest pdus first - .take_while(move |(k, _)| k.starts_with(&prefix2)) - .map(move |(key, _)| key[prefix3.len()..].to_vec()) - }); + self.tokenids + .iter_from(&last_possible_id, true) // Newest pdus first + .take_while(move |(k, _)| k.starts_with(&prefix2)) + .map(move |(key, _)| key[prefix3.len()..].to_vec()) + }); - let common_elements = match utils::common_elements(iterators, |a, b| { - // We compare b with a because we reversed the iterator earlier - b.cmp(a) - }) { - Some(it) => it, - None => return Ok(None), - }; + let common_elements = match utils::common_elements(iterators, |a, b| { + // We compare b with a because we reversed the iterator earlier + b.cmp(a) + }) { + Some(it) => it, + None => return Ok(None), + }; - 
Ok(Some((Box::new(common_elements), words))) - } + Ok(Some((Box::new(common_elements), words))) + } } diff --git a/src/database/key_value/rooms/short.rs b/src/database/key_value/rooms/short.rs index 40893986..1ae774c1 100644 --- a/src/database/key_value/rooms/short.rs +++ b/src/database/key_value/rooms/short.rs @@ -6,214 +6,165 @@ use tracing::warn; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::short::Data for KeyValueDatabase { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { - return Ok(*short); - } + fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + if let Some(short) = self.eventidshort_cache.lock().unwrap().get_mut(event_id) { + return Ok(*short); + } - let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { - Some(shorteventid) => utils::u64_from_bytes(&shorteventid) - .map_err(|_| Error::bad_database("Invalid shorteventid in db."))?, - None => { - let shorteventid = services().globals.next_count()?; - self.eventid_shorteventid - .insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; - self.shorteventid_eventid - .insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; - shorteventid - } - }; + let short = match self.eventid_shorteventid.get(event_id.as_bytes())? { + Some(shorteventid) => { + utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))? 
+ }, + None => { + let shorteventid = services().globals.next_count()?; + self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?; + self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?; + shorteventid + }, + }; - self.eventidshort_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), short); + self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short); - Ok(short) - } + Ok(short) + } - fn get_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result> { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(Some(*short)); - } + fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result> { + if let Some(short) = + self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(Some(*short)); + } - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xff); - statekey_vec.extend_from_slice(state_key.as_bytes()); + let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); + statekey_vec.push(0xFF); + statekey_vec.extend_from_slice(state_key.as_bytes()); - let short = self - .statekey_shortstatekey - .get(&statekey_vec)? - .map(|shortstatekey| { - utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) - }) - .transpose()?; + let short = self + .statekey_shortstatekey + .get(&statekey_vec)? 
+ .map(|shortstatekey| { + utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db.")) + }) + .transpose()?; - if let Some(s) = short { - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), s); - } + if let Some(s) = short { + self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s); + } - Ok(short) - } + Ok(short) + } - fn get_or_create_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result { - if let Some(short) = self - .statekeyshort_cache - .lock() - .unwrap() - .get_mut(&(event_type.clone(), state_key.to_owned())) - { - return Ok(*short); - } + fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + if let Some(short) = + self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned())) + { + return Ok(*short); + } - let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); - statekey_vec.push(0xff); - statekey_vec.extend_from_slice(state_key.as_bytes()); + let mut statekey_vec = event_type.to_string().as_bytes().to_vec(); + statekey_vec.push(0xFF); + statekey_vec.extend_from_slice(state_key.as_bytes()); - let short = match self.statekey_shortstatekey.get(&statekey_vec)? { - Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) - .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, - None => { - let shortstatekey = services().globals.next_count()?; - self.statekey_shortstatekey - .insert(&statekey_vec, &shortstatekey.to_be_bytes())?; - self.shortstatekey_statekey - .insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; - shortstatekey - } - }; + let short = match self.statekey_shortstatekey.get(&statekey_vec)? 
{ + Some(shortstatekey) => utils::u64_from_bytes(&shortstatekey) + .map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?, + None => { + let shortstatekey = services().globals.next_count()?; + self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?; + self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?; + shortstatekey + }, + }; - self.statekeyshort_cache - .lock() - .unwrap() - .insert((event_type.clone(), state_key.to_owned()), short); + self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short); - Ok(short) - } + Ok(short) + } - fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - if let Some(id) = self - .shorteventid_cache - .lock() - .unwrap() - .get_mut(&shorteventid) - { - return Ok(Arc::clone(id)); - } + fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) { + return Ok(Arc::clone(id)); + } - let bytes = self - .shorteventid_eventid - .get(&shorteventid.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; + let bytes = self + .shorteventid_eventid + .get(&shorteventid.to_be_bytes())? + .ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?; - let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in shorteventid_eventid is invalid unicode.") - })?) 
- .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; + let event_id = EventId::parse_arc( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?; - self.shorteventid_cache - .lock() - .unwrap() - .insert(shorteventid, Arc::clone(&event_id)); + self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id)); - Ok(event_id) - } + Ok(event_id) + } - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - if let Some(id) = self - .shortstatekey_cache - .lock() - .unwrap() - .get_mut(&shortstatekey) - { - return Ok(id.clone()); - } + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) { + return Ok(id.clone()); + } - let bytes = self - .shortstatekey_statekey - .get(&shortstatekey.to_be_bytes())? - .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; + let bytes = self + .shortstatekey_statekey + .get(&shortstatekey.to_be_bytes())? 
+ .ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?; - let mut parts = bytes.splitn(2, |&b| b == 0xff); - let eventtype_bytes = parts.next().expect("split always returns one entry"); - let statekey_bytes = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; + let mut parts = bytes.splitn(2, |&b| b == 0xFF); + let eventtype_bytes = parts.next().expect("split always returns one entry"); + let statekey_bytes = + parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?; - let event_type = - StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { - warn!("Event type in shortstatekey_statekey is invalid: {}", e); - Error::bad_database("Event type in shortstatekey_statekey is invalid.") - })?); + let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| { + warn!("Event type in shortstatekey_statekey is invalid: {}", e); + Error::bad_database("Event type in shortstatekey_statekey is invalid.") + })?); - let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| { - Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.") - })?; + let state_key = utils::string_from_bytes(statekey_bytes) + .map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?; - let result = (event_type, state_key); + let result = (event_type, state_key); - self.shortstatekey_cache - .lock() - .unwrap() - .insert(shortstatekey, result.clone()); + self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone()); - Ok(result) - } + Ok(result) + } - /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - Ok(match self.statehash_shortstatehash.get(state_hash)? 
{ - Some(shortstatehash) => ( - utils::u64_from_bytes(&shortstatehash) - .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, - true, - ), - None => { - let shortstatehash = services().globals.next_count()?; - self.statehash_shortstatehash - .insert(state_hash, &shortstatehash.to_be_bytes())?; - (shortstatehash, false) - } - }) - } + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + Ok(match self.statehash_shortstatehash.get(state_hash)? { + Some(shortstatehash) => ( + utils::u64_from_bytes(&shortstatehash) + .map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?, + true, + ), + None => { + let shortstatehash = services().globals.next_count()?; + self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?; + (shortstatehash, false) + }, + }) + } - fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.roomid_shortroomid - .get(room_id.as_bytes())? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid shortroomid in db.")) - }) - .transpose() - } + fn get_shortroomid(&self, room_id: &RoomId) -> Result> { + self.roomid_shortroomid + .get(room_id.as_bytes())? + .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db."))) + .transpose() + } - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? { - Some(short) => utils::u64_from_bytes(&short) - .map_err(|_| Error::bad_database("Invalid shortroomid in db."))?, - None => { - let short = services().globals.next_count()?; - self.roomid_shortroomid - .insert(room_id.as_bytes(), &short.to_be_bytes())?; - short - } - }) - } + fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? 
{ + Some(short) => { + utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))? + }, + None => { + let short = services().globals.next_count()?; + self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?; + short + }, + }) + } } diff --git a/src/database/key_value/rooms/state.rs b/src/database/key_value/rooms/state.rs index fd0c81e6..f11d2df3 100644 --- a/src/database/key_value/rooms/state.rs +++ b/src/database/key_value/rooms/state.rs @@ -1,73 +1,69 @@ -use ruma::{EventId, OwnedEventId, RoomId}; -use std::collections::HashSet; +use std::{collections::HashSet, sync::Arc}; -use std::sync::Arc; +use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; use crate::{database::KeyValueDatabase, service, utils, Error, Result}; impl service::rooms::state::Data for KeyValueDatabase { - fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.roomid_shortstatehash - .get(room_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") - })?)) - }) - } + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| { + Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Invalid shortstatehash in roomid_shortstatehash") + })?)) + }) + } - fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.roomid_shortstatehash - .insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; - Ok(()) - } + fn set_room_state( + &self, + room_id: &RoomId, + new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + 
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?; + Ok(()) + } - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { - self.shorteventid_shortstatehash - .insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; - Ok(()) - } + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> { + self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?; + Ok(()) + } - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - self.roomid_pduleaves - .scan_prefix(prefix) - .map(|(_, bytes)| { - EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("EventID in roomid_pduleaves is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) - }) - .collect() - } + self.roomid_pduleaves + .scan_prefix(prefix) + .map(|(_, bytes)| { + EventId::parse_arc( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid.")) + }) + .collect() + } - fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); + fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); - for (key, _) in 
self.roomid_pduleaves.scan_prefix(prefix.clone()) { - self.roomid_pduleaves.remove(&key)?; - } + for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) { + self.roomid_pduleaves.remove(&key)?; + } - for event_id in event_ids { - let mut key = prefix.clone(); - key.extend_from_slice(event_id.as_bytes()); - self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; - } + for event_id in event_ids { + let mut key = prefix.clone(); + key.extend_from_slice(event_id.as_bytes()); + self.roomid_pduleaves.insert(&key, event_id.as_bytes())?; + } - Ok(()) - } + Ok(()) + } } diff --git a/src/database/key_value/rooms/state_accessor.rs b/src/database/key_value/rooms/state_accessor.rs index fe40b937..c9cf585c 100644 --- a/src/database/key_value/rooms/state_accessor.rs +++ b/src/database/key_value/rooms/state_accessor.rs @@ -1,186 +1,144 @@ use std::{collections::HashMap, sync::Arc}; -use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; use async_trait::async_trait; use ruma::{events::StateEventType, EventId, RoomId}; +use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; + #[async_trait] impl service::rooms::state_accessor::Data for KeyValueDatabase { - async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state.iter() { - let parsed = services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed)?; - result.insert(parsed.0, parsed.1); + async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? 
+ .pop() + .expect("there is always one layer") + .1; + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state.iter() { + let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; + result.insert(parsed.0, parsed.1); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } - Ok(result) - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } + Ok(result) + } - async fn state_full( - &self, - shortstatehash: u64, - ) -> Result>> { - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; + async fn state_full(&self, shortstatehash: u64) -> Result>> { + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; - let mut result = HashMap::new(); - let mut i = 0; - for compressed in full_state.iter() { - let (_, eventid) = services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed)?; - if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { - result.insert( - ( - pdu.kind.to_string().into(), - pdu.state_key - .as_ref() - .ok_or_else(|| Error::bad_database("State event has no state key."))? - .clone(), - ), - pdu, - ); - } + let mut result = HashMap::new(); + let mut i = 0; + for compressed in full_state.iter() { + let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?; + if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? { + result.insert( + ( + pdu.kind.to_string().into(), + pdu.state_key + .as_ref() + .ok_or_else(|| Error::bad_database("State event has no state key."))? 
+ .clone(), + ), + pdu, + ); + } - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } - Ok(result) - } + Ok(result) + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn state_get_id( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - let shortstatekey = match services() - .rooms - .short - .get_shortstatekey(event_type, state_key)? - { - Some(s) => s, - None => return Ok(None), - }; - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; - Ok(full_state - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - .and_then(|compressed| { - services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - .map(|(_, id)| id) - })) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn state_get_id( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? { + Some(s) => s, + None => return Ok(None), + }; + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? + .pop() + .expect("there is always one layer") + .1; + Ok( + full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| { + services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id) + }), + ) + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn state_get( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - self.state_get_id(shortstatehash, event_type, state_key)? 
- .map_or(Ok(None), |event_id| { - services().rooms.timeline.get_pdu(&event_id) - }) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn state_get( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.state_get_id(shortstatehash, event_type, state_key)? + .map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id)) + } - /// Returns the state hash for this pdu. - fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.eventid_shorteventid - .get(event_id.as_bytes())? - .map_or(Ok(None), |shorteventid| { - self.shorteventid_shortstatehash - .get(&shorteventid)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "Invalid shortstatehash bytes in shorteventid_shortstatehash", - ) - }) - }) - .transpose() - }) - } + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { + self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| { + self.shorteventid_shortstatehash + .get(&shorteventid)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")) + }) + .transpose() + }) + } - /// Returns the full room state. - async fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result>> { - if let Some(current_shortstatehash) = - services().rooms.state.get_room_shortstatehash(room_id)? - { - self.state_full(current_shortstatehash).await - } else { - Ok(HashMap::new()) - } - } + /// Returns the full room state. + async fn room_state_full(&self, room_id: &RoomId) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + self.state_full(current_shortstatehash).await + } else { + Ok(HashMap::new()) + } + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). 
- fn room_state_get_id( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = - services().rooms.state.get_room_shortstatehash(room_id)? - { - self.state_get_id(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn room_state_get_id( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + self.state_get_id(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn room_state_get( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - if let Some(current_shortstatehash) = - services().rooms.state.get_room_shortstatehash(room_id)? - { - self.state_get(current_shortstatehash, event_type, state_key) - } else { - Ok(None) - } - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn room_state_get( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? 
{ + self.state_get(current_shortstatehash, event_type, state_key) + } else { + Ok(None) + } + } } diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index 76d397f9..c2cb3833 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -2,624 +2,459 @@ use std::{collections::HashSet, sync::Arc}; use regex::Regex; use ruma::{ - api::appservice::Registration, - events::{AnyStrippedStateEvent, AnySyncStateEvent}, - serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + api::appservice::Registration, + events::{AnyStrippedStateEvent, AnySyncStateEvent}, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; -type StrippedStateEventIter<'a> = - Box>)>> + 'a>; +type StrippedStateEventIter<'a> = Box>)>> + 'a>; -type AnySyncStateEventIter<'a> = - Box>)>> + 'a>; +type AnySyncStateEventIter<'a> = Box>)>> + 'a>; impl service::rooms::state_cache::Data for KeyValueDatabase { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - self.roomuseroncejoinedids.insert(&userroom_id, &[]) - } - - fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_joined.insert(&userroom_id, &[])?; - self.roomuserid_joined.insert(&roomuser_id, &[])?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - 
self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - fn mark_as_invited( - &self, - user_id: &UserId, - room_id: &RoomId, - last_state: Option>>, - ) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate.insert( - &userroom_id, - &serde_json::to_vec(&last_state.unwrap_or_default()) - .expect("state to bytes always works"), - )?; - self.roomuserid_invitecount.insert( - &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), - )?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate.insert( - &userroom_id, - &serde_json::to_vec(&Vec::>::new()).unwrap(), - )?; // TODO - self.roomuserid_leftcount.insert( - &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), - )?; - self.userroomid_joined.remove(&userroom_id)?; - self.roomuserid_joined.remove(&roomuser_id)?; - self.userroomid_invitestate.remove(&userroom_id)?; - self.roomuserid_invitecount.remove(&roomuser_id)?; - - Ok(()) - } - - fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - let mut joinedcount = 0_u64; - let mut invitedcount = 0_u64; - let mut joined_servers = HashSet::new(); - let mut real_users = HashSet::new(); - - for joined in 
self - .room_members(room_id) - .filter_map(std::result::Result::ok) - { - joined_servers.insert(joined.server_name().to_owned()); - if joined.server_name() == services().globals.server_name() - && !services().users.is_deactivated(&joined).unwrap_or(true) - { - real_users.insert(joined); - } - joinedcount += 1; - } - - for _invited in self - .room_members_invited(room_id) - .filter_map(std::result::Result::ok) - { - invitedcount += 1; - } - - self.roomid_joinedcount - .insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; - - self.roomid_invitedcount - .insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; - - self.our_real_users_cache - .write() - .unwrap() - .insert(room_id.to_owned(), Arc::new(real_users)); - - for old_joined_server in self - .room_servers(room_id) - .filter_map(std::result::Result::ok) - { - if !joined_servers.remove(&old_joined_server) { - // Server not in room anymore - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(old_joined_server.as_bytes()); - - let mut serverroom_id = old_joined_server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.remove(&roomserver_id)?; - self.serverroomids.remove(&serverroom_id)?; - } - } - - // Now only new servers are in joined_servers anymore - for server in joined_servers { - let mut roomserver_id = room_id.as_bytes().to_vec(); - roomserver_id.push(0xff); - roomserver_id.extend_from_slice(server.as_bytes()); - - let mut serverroom_id = server.as_bytes().to_vec(); - serverroom_id.push(0xff); - serverroom_id.extend_from_slice(room_id.as_bytes()); - - self.roomserverids.insert(&roomserver_id, &[])?; - self.serverroomids.insert(&serverroom_id, &[])?; - } - - self.appservice_in_room_cache - .write() - .unwrap() - .remove(room_id); - - Ok(()) - } - - #[tracing::instrument(skip(self, room_id))] - fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { - let 
maybe = self - .our_real_users_cache - .read() - .unwrap() - .get(room_id) - .cloned(); - if let Some(users) = maybe { - Ok(users) - } else { - self.update_joined_count(room_id)?; - Ok(Arc::clone( - self.our_real_users_cache - .read() - .unwrap() - .get(room_id) - .unwrap(), - )) - } - } - - #[tracing::instrument(skip(self, room_id, appservice))] - fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, Registration), - ) -> Result { - let maybe = self - .appservice_in_room_cache - .read() - .unwrap() - .get(room_id) - .and_then(|map| map.get(&appservice.0)) - .copied(); - - if let Some(b) = maybe { - Ok(b) - } else { - let namespaces = &appservice.1.namespaces; - let users = namespaces - .users - .iter() - .filter_map(|users| Regex::new(users.regex.as_str()).ok()) - .collect::>(); - - let bridge_user_id = UserId::parse_with_server_name( - appservice.1.sender_localpart.as_str(), - services().globals.server_name(), - ) - .ok(); - - let in_room = bridge_user_id - .map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) - || self.room_members(room_id).any(|userid| { - userid.map_or(false, |userid| { - users.iter().any(|r| r.is_match(userid.as_str())) - }) - }); - - self.appservice_in_room_cache - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default() - .insert(appservice.0.clone(), in_room); - - Ok(in_room) - } - } - - /// Makes a user forget a room. - #[tracing::instrument(skip(self))] - fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); - - self.userroomid_leftstate.remove(&userroom_id)?; - self.roomuserid_leftcount.remove(&roomuser_id)?; - - Ok(()) - } - - /// Returns an iterator of all servers participating in this room. 
- #[tracing::instrument(skip(self))] - fn room_servers<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { - ServerName::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Server name in roomserverids is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) - })) - } - - #[tracing::instrument(skip(self))] - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - let mut key = server.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.serverroomids.get(&key).map(|o| o.is_some()) - } - - /// Returns an iterator of all rooms a server participates in (as far as we know). - #[tracing::instrument(skip(self))] - fn server_rooms<'a>( - &'a self, - server: &ServerName, - ) -> Box> + 'a> { - let mut prefix = server.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, - ) - .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) - })) - } - - /// Returns an iterator over all joined members of a room. 
- #[tracing::instrument(skip(self))] - fn room_members<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_joined is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) - })) - } - - #[tracing::instrument(skip(self))] - fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.roomid_joinedcount - .get(room_id.as_bytes())? - .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.roomid_invitedcount - .get(room_id.as_bytes())? - .map(|b| { - utils::u64_from_bytes(&b) - .map_err(|_| Error::bad_database("Invalid joinedcount in db.")) - }) - .transpose() - } - - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] - fn room_useroncejoined<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.roomuseroncejoinedids - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database( - "User ID in room_useroncejoined is invalid unicode.", - ) - })?, - ) - .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) - }), - ) - } - - /// Returns an iterator over all invited members of a room. 
- #[tracing::instrument(skip(self))] - fn room_members_invited<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - let mut prefix = room_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.roomuserid_invitecount - .scan_prefix(prefix) - .map(|(key, _)| { - UserId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("User ID in roomuserid_invited is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_invitecount - .get(&key)? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid invitecount in db.") - })?)) - }) - } - - #[tracing::instrument(skip(self))] - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); - - self.roomuserid_leftcount - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid leftcount in db.")) - }) - .transpose() - } - - /// Returns an iterator over all rooms this user joined. 
- #[tracing::instrument(skip(self))] - fn rooms_joined<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a> { - Box::new( - self.userroomid_joined - .scan_prefix(user_id.as_bytes().to_vec()) - .map(|(key, _)| { - RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_joined is invalid unicode.") - })?, - ) - .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) - }), - ) - } - - /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] - fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.userroomid_invitestate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_invitestate.") - })?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn invite_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_invitestate - .get(&key)? 
- .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; - - Ok(state) - }) - .transpose() - } - - #[tracing::instrument(skip(self))] - fn left_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(room_id.as_bytes()); - - self.userroomid_leftstate - .get(&key)? - .map(|state| { - let state = serde_json::from_slice(&state) - .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; - - Ok(state) - }) - .transpose() - } - - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] - fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - Box::new( - self.userroomid_leftstate - .scan_prefix(prefix) - .map(|(key, state)| { - let room_id = RoomId::parse( - utils::string_from_bytes( - key.rsplit(|&b| b == 0xff) - .next() - .expect("rsplit always returns an element"), - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid unicode.") - })?, - ) - .map_err(|_| { - Error::bad_database("Room ID in userroomid_invited is invalid.") - })?; - - let state = serde_json::from_slice(&state).map_err(|_| { - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok((room_id, state)) - }), - ) - } - - #[tracing::instrument(skip(self))] - fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - 
userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) - } - - #[tracing::instrument(skip(self))] - fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - - Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) - } + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + self.roomuseroncejoinedids.insert(&userroom_id, &[]) + } + + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_joined.insert(&userroom_id, &[])?; + self.roomuserid_joined.insert(&roomuser_id, &[])?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + ) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + 
userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate.insert( + &userroom_id, + &serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"), + )?; + self.roomuserid_invitecount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate.insert( + &userroom_id, + &serde_json::to_vec(&Vec::>::new()).unwrap(), + )?; // TODO + self.roomuserid_leftcount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; + self.userroomid_joined.remove(&userroom_id)?; + self.roomuserid_joined.remove(&roomuser_id)?; + self.userroomid_invitestate.remove(&userroom_id)?; + self.roomuserid_invitecount.remove(&roomuser_id)?; + + Ok(()) + } + + fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { + let mut joinedcount = 0_u64; + let mut invitedcount = 0_u64; + let mut joined_servers = HashSet::new(); + let mut real_users = HashSet::new(); + + for joined in self.room_members(room_id).filter_map(std::result::Result::ok) { + joined_servers.insert(joined.server_name().to_owned()); + if joined.server_name() == services().globals.server_name() + && !services().users.is_deactivated(&joined).unwrap_or(true) + { + real_users.insert(joined); + } + joinedcount += 1; + } + + for _invited in self.room_members_invited(room_id).filter_map(std::result::Result::ok) { + invitedcount += 1; + } + + 
self.roomid_joinedcount.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?; + + self.roomid_invitedcount.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?; + + self.our_real_users_cache.write().unwrap().insert(room_id.to_owned(), Arc::new(real_users)); + + for old_joined_server in self.room_servers(room_id).filter_map(std::result::Result::ok) { + if !joined_servers.remove(&old_joined_server) { + // Server not in room anymore + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(old_joined_server.as_bytes()); + + let mut serverroom_id = old_joined_server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.remove(&roomserver_id)?; + self.serverroomids.remove(&serverroom_id)?; + } + } + + // Now only new servers are in joined_servers anymore + for server in joined_servers { + let mut roomserver_id = room_id.as_bytes().to_vec(); + roomserver_id.push(0xFF); + roomserver_id.extend_from_slice(server.as_bytes()); + + let mut serverroom_id = server.as_bytes().to_vec(); + serverroom_id.push(0xFF); + serverroom_id.extend_from_slice(room_id.as_bytes()); + + self.roomserverids.insert(&roomserver_id, &[])?; + self.serverroomids.insert(&serverroom_id, &[])?; + } + + self.appservice_in_room_cache.write().unwrap().remove(room_id); + + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { + let maybe = self.our_real_users_cache.read().unwrap().get(room_id).cloned(); + if let Some(users) = maybe { + Ok(users) + } else { + self.update_joined_count(room_id)?; + Ok(Arc::clone(self.our_real_users_cache.read().unwrap().get(room_id).unwrap())) + } + } + + #[tracing::instrument(skip(self, room_id, appservice))] + fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result { + let maybe = + 
self.appservice_in_room_cache.read().unwrap().get(room_id).and_then(|map| map.get(&appservice.0)).copied(); + + if let Some(b) = maybe { + Ok(b) + } else { + let namespaces = &appservice.1.namespaces; + let users = + namespaces.users.iter().filter_map(|users| Regex::new(users.regex.as_str()).ok()).collect::>(); + + let bridge_user_id = UserId::parse_with_server_name( + appservice.1.sender_localpart.as_str(), + services().globals.server_name(), + ) + .ok(); + + let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false)) + || self + .room_members(room_id) + .any(|userid| userid.map_or(false, |userid| users.iter().any(|r| r.is_match(userid.as_str())))); + + self.appservice_in_room_cache + .write() + .unwrap() + .entry(room_id.to_owned()) + .or_default() + .insert(appservice.0.clone(), in_room); + + Ok(in_room) + } + } + + /// Makes a user forget a room. + #[tracing::instrument(skip(self))] + fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); + + self.userroomid_leftstate.remove(&userroom_id)?; + self.roomuserid_leftcount.remove(&roomuser_id)?; + + Ok(()) + } + + /// Returns an iterator of all servers participating in this room. 
+ #[tracing::instrument(skip(self))] + fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| { + ServerName::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Server name in roomserverids is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + let mut key = server.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.serverroomids.get(&key).map(|o| o.is_some()) + } + + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). + #[tracing::instrument(skip(self))] + fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a> { + let mut prefix = server.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("RoomId in serverroomids is invalid.")) + })) + } + + /// Returns an iterator over all joined members of a room. 
+ #[tracing::instrument(skip(self))] + fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| { + UserId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn room_joined_count(&self, room_id: &RoomId) -> Result> { + self.roomid_joinedcount + .get(room_id.as_bytes())? + .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn room_invited_count(&self, room_id: &RoomId) -> Result> { + self.roomid_invitedcount + .get(room_id.as_bytes())? + .map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db."))) + .transpose() + } + + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self))] + fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(|(key, _)| { + UserId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid.")) + })) + } + + /// Returns an iterator over all invited members of a room. 
+ #[tracing::instrument(skip(self))] + fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { + let mut prefix = room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(|(key, _)| { + UserId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid.")) + })) + } + + #[tracing::instrument(skip(self))] + fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?, + )) + }) + } + + #[tracing::instrument(skip(self))] + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); + + self.roomuserid_leftcount + .get(&key)? + .map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db."))) + .transpose() + } + + /// Returns an iterator over all rooms this user joined. 
+ #[tracing::instrument(skip(self))] + fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + Box::new(self.userroomid_joined.scan_prefix(user_id.as_bytes().to_vec()).map(|(key, _)| { + RoomId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid.")) + })) + } + + /// Returns an iterator over all rooms a user was invited to. + #[tracing::instrument(skip(self))] + fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok((room_id, state)) + })) + } + + #[tracing::instrument(skip(self))] + fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_invitestate + .get(&key)? 
+ .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(room_id.as_bytes()); + + self.userroomid_leftstate + .get(&key)? + .map(|state| { + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok(state) + }) + .transpose() + } + + /// Returns an iterator over all rooms a user left. + #[tracing::instrument(skip(self))] + fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(|(key, state)| { + let room_id = RoomId::parse( + utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element")) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?; + + let state = serde_json::from_slice(&state) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok((room_id, state)) + })) + } + + #[tracing::instrument(skip(self))] + fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + 
Ok(self.userroomid_joined.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some()) + } + + #[tracing::instrument(skip(self))] + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + + Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) + } } diff --git a/src/database/key_value/rooms/state_compressor.rs b/src/database/key_value/rooms/state_compressor.rs index 65ea603e..9be3a196 100644 --- a/src/database/key_value/rooms/state_compressor.rs +++ b/src/database/key_value/rooms/state_compressor.rs @@ -1,61 +1,63 @@ use std::{collections::HashSet, mem::size_of, sync::Arc}; use crate::{ - database::KeyValueDatabase, - service::{self, rooms::state_compressor::data::StateDiff}, - utils, Error, Result, + database::KeyValueDatabase, + service::{self, rooms::state_compressor::data::StateDiff}, + utils, Error, Result, }; impl service::rooms::state_compressor::Data for KeyValueDatabase { - fn get_statediff(&self, shortstatehash: u64) -> Result { - let value = self - .shortstatehash_statediff - .get(&shortstatehash.to_be_bytes())? - .ok_or_else(|| Error::bad_database("State hash does not exist"))?; - let parent = - utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); - let parent = if parent != 0 { Some(parent) } else { None }; + fn get_statediff(&self, shortstatehash: u64) -> Result { + let value = self + .shortstatehash_statediff + .get(&shortstatehash.to_be_bytes())? 
+ .ok_or_else(|| Error::bad_database("State hash does not exist"))?; + let parent = utils::u64_from_bytes(&value[0..size_of::()]).expect("bytes have right length"); + let parent = if parent != 0 { + Some(parent) + } else { + None + }; - let mut add_mode = true; - let mut added = HashSet::new(); - let mut removed = HashSet::new(); + let mut add_mode = true; + let mut added = HashSet::new(); + let mut removed = HashSet::new(); - let mut i = size_of::(); - while let Some(v) = value.get(i..i + 2 * size_of::()) { - if add_mode && v.starts_with(&0_u64.to_be_bytes()) { - add_mode = false; - i += size_of::(); - continue; - } - if add_mode { - added.insert(v.try_into().expect("we checked the size above")); - } else { - removed.insert(v.try_into().expect("we checked the size above")); - } - i += 2 * size_of::(); - } + let mut i = size_of::(); + while let Some(v) = value.get(i..i + 2 * size_of::()) { + if add_mode && v.starts_with(&0_u64.to_be_bytes()) { + add_mode = false; + i += size_of::(); + continue; + } + if add_mode { + added.insert(v.try_into().expect("we checked the size above")); + } else { + removed.insert(v.try_into().expect("we checked the size above")); + } + i += 2 * size_of::(); + } - Ok(StateDiff { - parent, - added: Arc::new(added), - removed: Arc::new(removed), - }) - } + Ok(StateDiff { + parent, + added: Arc::new(added), + removed: Arc::new(removed), + }) + } - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { - let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); - for new in diff.added.iter() { - value.extend_from_slice(&new[..]); - } + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()> { + let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec(); + for new in diff.added.iter() { + value.extend_from_slice(&new[..]); + } - if !diff.removed.is_empty() { - value.extend_from_slice(&0_u64.to_be_bytes()); - for removed in diff.removed.iter() { - value.extend_from_slice(&removed[..]); - } 
- } + if !diff.removed.is_empty() { + value.extend_from_slice(&0_u64.to_be_bytes()); + for removed in diff.removed.iter() { + value.extend_from_slice(&removed[..]); + } + } - self.shortstatehash_statediff - .insert(&shortstatehash.to_be_bytes(), &value) - } + self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value) + } } diff --git a/src/database/key_value/rooms/threads.rs b/src/database/key_value/rooms/threads.rs index 794b5fd7..08aaec0d 100644 --- a/src/database/key_value/rooms/threads.rs +++ b/src/database/key_value/rooms/threads.rs @@ -7,74 +7,58 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEven type PduEventIterResult<'a> = Result> + 'a>>; impl service::rooms::threads::Data for KeyValueDatabase { - fn threads_until<'a>( - &'a self, - user_id: &'a UserId, - room_id: &'a RoomId, - until: u64, - _include: &'a IncludeThreads, - ) -> PduEventIterResult<'a> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); + fn threads_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, + ) -> PduEventIterResult<'a> { + let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec(); - let mut current = prefix.clone(); - current.extend_from_slice(&(until - 1).to_be_bytes()); + let mut current = prefix.clone(); + current.extend_from_slice(&(until - 1).to_be_bytes()); - Ok(Box::new( - self.threadid_userids - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pduid, _users)| { - let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) - .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; - let mut pdu = services() - .rooms - .timeline - .get_pdu_from_id(&pduid)? 
- .ok_or_else(|| { - Error::bad_database("Invalid pduid reference in threadid_userids") - })?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - Ok((count, pdu)) - }), - )) - } + Ok(Box::new( + self.threadid_userids.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( + move |(pduid, _users)| { + let count = utils::u64_from_bytes(&pduid[(mem::size_of::())..]) + .map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?; + let mut pdu = services() + .rooms + .timeline + .get_pdu_from_id(&pduid)? + .ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + Ok((count, pdu)) + }, + ), + )) + } - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { - let users = participants - .iter() - .map(|user| user.as_bytes()) - .collect::>() - .join(&[0xff][..]); + fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + let users = participants.iter().map(|user| user.as_bytes()).collect::>().join(&[0xFF][..]); - self.threadid_userids.insert(root_id, &users)?; + self.threadid_userids.insert(root_id, &users)?; - Ok(()) - } + Ok(()) + } - fn get_participants(&self, root_id: &[u8]) -> Result>> { - if let Some(users) = self.threadid_userids.get(root_id)? { - Ok(Some( - users - .split(|b| *b == 0xff) - .map(|bytes| { - UserId::parse(utils::string_from_bytes(bytes).map_err(|_| { - Error::bad_database("Invalid UserId bytes in threadid_userids.") - })?) - .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) - }) - .filter_map(std::result::Result::ok) - .collect(), - )) - } else { - Ok(None) - } - } + fn get_participants(&self, root_id: &[u8]) -> Result>> { + if let Some(users) = self.threadid_userids.get(root_id)? 
{ + Ok(Some( + users + .split(|b| *b == 0xFF) + .map(|bytes| { + UserId::parse( + utils::string_from_bytes(bytes) + .map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?, + ) + .map_err(|_| Error::bad_database("Invalid UserId in threadid_userids.")) + }) + .filter_map(std::result::Result::ok) + .collect(), + )) + } else { + Ok(None) + } + } } diff --git a/src/database/key_value/rooms/timeline.rs b/src/database/key_value/rooms/timeline.rs index 0331a624..63b57f9b 100644 --- a/src/database/key_value/rooms/timeline.rs +++ b/src/database/key_value/rooms/timeline.rs @@ -1,364 +1,286 @@ use std::{collections::hash_map, mem::size_of, sync::Arc}; -use ruma::{ - api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId, -}; +use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; +use service::rooms::timeline::PduCount; use tracing::error; use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result}; -use service::rooms::timeline::PduCount; - impl service::rooms::timeline::Data for KeyValueDatabase { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - match self - .lasttimelinecount_cache - .lock() - .unwrap() - .entry(room_id.to_owned()) - { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max())? 
- .find_map(|r| { - // Filter out buggy events - if r.is_err() { - error!("Bad pdu in pdus_since: {:?}", r); - } - r.ok() - }) - { - Ok(*v.insert(last_count.0)) - } else { - Ok(PduCount::Normal(0)) - } - } - hash_map::Entry::Occupied(o) => Ok(*o.get()), - } - } + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) { + hash_map::Entry::Vacant(v) => { + if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| { + // Filter out buggy events + if r.is_err() { + error!("Bad pdu in pdus_since: {:?}", r); + } + r.ok() + }) { + Ok(*v.insert(last_count.0)) + } else { + Ok(PduCount::Normal(0)) + } + }, + hash_map::Entry::Occupied(o) => Ok(*o.get()), + } + } - /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pdu_id| pdu_count(&pdu_id)) - .transpose() - } + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result> { + self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose() + } - /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.get_non_outlier_pdu_json(event_id)?.map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - }, - |x| Ok(Some(x)), - ) - } + /// Returns the json of a pdu. + fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.get_non_outlier_pdu_json(event_id)?.map_or_else( + || { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + }, + |x| Ok(Some(x)), + ) + } - /// Returns the json of a pdu. 
- fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } + /// Returns the json of a pdu. + fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + } - /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.eventid_pduid.get(event_id.as_bytes()) - } + /// Returns the pdu's id. + fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.eventid_pduid.get(event_id.as_bytes()) } - /// Returns the pdu. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.eventid_pduid - .get(event_id.as_bytes())? - .map(|pduid| { - self.pduid_pdu - .get(&pduid)? - .ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) - }) - .transpose()? - .map(|pdu| { - serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - } + /// Returns the pdu. + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.eventid_pduid + .get(event_id.as_bytes())? + .map(|pduid| { + self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid.")) + }) + .transpose()? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + } - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. 
- fn get_pdu(&self, event_id: &EventId) -> Result>> { - if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { - return Ok(Some(Arc::clone(p))); - } + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_pdu(&self, event_id: &EventId) -> Result>> { + if let Some(p) = self.pdu_cache.lock().unwrap().get_mut(event_id) { + return Ok(Some(Arc::clone(p))); + } - if let Some(pdu) = self - .get_non_outlier_pdu(event_id)? - .map_or_else( - || { - self.eventid_outlierpdu - .get(event_id.as_bytes())? - .map(|pdu| { - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db.")) - }) - .transpose() - }, - |x| Ok(Some(x)), - )? - .map(Arc::new) - { - self.pdu_cache - .lock() - .unwrap() - .insert(event_id.to_owned(), Arc::clone(&pdu)); - Ok(Some(pdu)) - } else { - Ok(None) - } - } + if let Some(pdu) = self + .get_non_outlier_pdu(event_id)? + .map_or_else( + || { + self.eventid_outlierpdu + .get(event_id.as_bytes())? + .map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))) + .transpose() + }, + |x| Ok(Some(x)), + )? + .map(Arc::new) + { + self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu)); + Ok(Some(pdu)) + } else { + Ok(None) + } + } - /// Returns the pdu. - /// - /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } - /// Returns the pdu as a `BTreeMap`. 
- fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { - Ok(Some( - serde_json::from_slice(&pdu) - .map_err(|_| Error::bad_database("Invalid PDU in db."))?, - )) - }) - } + /// Returns the pdu as a `BTreeMap`. + fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| { + Ok(Some( + serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?, + )) + }) + } - fn append_pdu( - &self, - pdu_id: &[u8], - pdu: &PduEvent, - json: &CanonicalJsonObject, - count: u64, - ) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; - self.lasttimelinecount_cache - .lock() - .unwrap() - .insert(pdu.room_id.clone(), PduCount::Normal(count)); + self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count)); - self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; - self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; + self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?; - Ok(()) - } + Ok(()) + } - fn prepend_backfill_pdu( - &self, - pdu_id: &[u8], - event_id: &EventId, - json: &CanonicalJsonObject, - ) -> Result<()> { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), - )?; + fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"), + )?; - self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; - 
self.eventid_outlierpdu.remove(event_id.as_bytes())?; + self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?; + self.eventid_outlierpdu.remove(event_id.as_bytes())?; - Ok(()) - } + Ok(()) + } - /// Removes a pdu and creates a new one with the same id. - fn replace_pdu( - &self, - pdu_id: &[u8], - pdu_json: &CanonicalJsonObject, - pdu: &PduEvent, - ) -> Result<()> { - if self.pduid_pdu.get(pdu_id)?.is_some() { - self.pduid_pdu.insert( - pdu_id, - &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), - )?; - } else { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "PDU does not exist.", - )); - } + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + if self.pduid_pdu.get(pdu_id)?.is_some() { + self.pduid_pdu.insert( + pdu_id, + &serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"), + )?; + } else { + return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist.")); + } - self.pdu_cache - .lock() - .unwrap() - .remove(&(*pdu.event_id).to_owned()); + self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned()); - Ok(()) - } + Ok(()) + } - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. - fn pdus_until<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - until: PduCount, - ) -> Result> + 'a>> { - let (prefix, current) = count_to_id(room_id, until, 1, true)?; + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. 
+ fn pdus_until<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, + ) -> Result> + 'a>> { + let (prefix, current) = count_to_id(room_id, until, 1, true)?; - let user_id = user_id.to_owned(); + let user_id = user_id.to_owned(); - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) - } + Ok(Box::new( + self.pduid_pdu.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map( + move |(pdu_id, v)| { + let mut pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + pdu.add_age()?; + let count = pdu_count(&pdu_id)?; + Ok((count, pdu)) + }, + ), + )) + } - fn pdus_after<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - from: PduCount, - ) -> Result> + 'a>> { - let (prefix, current) = count_to_id(room_id, from, 1, false)?; + fn pdus_after<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, + ) -> Result> + 'a>> { + let (prefix, current) = count_to_id(room_id, from, 1, false)?; - let user_id = user_id.to_owned(); + let user_id = user_id.to_owned(); - Ok(Box::new( - self.pduid_pdu - .iter_from(¤t, false) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(move |(pdu_id, v)| { - let mut pdu = serde_json::from_slice::(&v) - .map_err(|_| Error::bad_database("PDU in db is invalid."))?; - if pdu.sender != user_id { - pdu.remove_transaction_id()?; - } - pdu.add_age()?; - let count = pdu_count(&pdu_id)?; - Ok((count, pdu)) - }), - )) - } + Ok(Box::new( + self.pduid_pdu.iter_from(¤t, false).take_while(move |(k, _)| k.starts_with(&prefix)).map( + move |(pdu_id, v)| { + let mut 
pdu = serde_json::from_slice::(&v) + .map_err(|_| Error::bad_database("PDU in db is invalid."))?; + if pdu.sender != user_id { + pdu.remove_transaction_id()?; + } + pdu.add_age()?; + let count = pdu_count(&pdu_id)?; + Ok((count, pdu)) + }, + ), + )) + } - fn increment_notification_counts( - &self, - room_id: &RoomId, - notifies: Vec, - highlights: Vec, - ) -> Result<()> { - let mut notifies_batch = Vec::new(); - let mut highlights_batch = Vec::new(); - for user in notifies { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - notifies_batch.push(userroom_id); - } - for user in highlights { - let mut userroom_id = user.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - highlights_batch.push(userroom_id); - } + fn increment_notification_counts( + &self, room_id: &RoomId, notifies: Vec, highlights: Vec, + ) -> Result<()> { + let mut notifies_batch = Vec::new(); + let mut highlights_batch = Vec::new(); + for user in notifies { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + notifies_batch.push(userroom_id); + } + for user in highlights { + let mut userroom_id = user.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + highlights_batch.push(userroom_id); + } - self.userroomid_notificationcount - .increment_batch(&mut notifies_batch.into_iter())?; - self.userroomid_highlightcount - .increment_batch(&mut highlights_batch.into_iter())?; - Ok(()) - } + self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?; + self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?; + Ok(()) + } } /// Returns the `count` of this pdu's id. 
fn pdu_count(pdu_id: &[u8]) -> Result { - let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) - .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; - let second_last_u64 = utils::u64_from_bytes( - &pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()], - ); + let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::()..]) + .map_err(|_| Error::bad_database("PDU has invalid count bytes."))?; + let second_last_u64 = + utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::()..pdu_id.len() - size_of::()]); - if matches!(second_last_u64, Ok(0)) { - Ok(PduCount::Backfilled(u64::MAX - last_u64)) - } else { - Ok(PduCount::Normal(last_u64)) - } + if matches!(second_last_u64, Ok(0)) { + Ok(PduCount::Backfilled(u64::MAX - last_u64)) + } else { + Ok(PduCount::Normal(last_u64)) + } } -fn count_to_id( - room_id: &RoomId, - count: PduCount, - offset: u64, - subtract: bool, -) -> Result<(Vec, Vec)> { - let prefix = services() - .rooms - .short - .get_shortroomid(room_id)? - .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? - .to_be_bytes() - .to_vec(); - let mut pdu_id = prefix.clone(); - // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x - offset - } else { - x + offset - } - } - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX - x; - if subtract { - if num > 0 { - num - offset - } else { - num - } - } else { - num + offset - } - } - }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); +fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec, Vec)> { + let prefix = services() + .rooms + .short + .get_shortroomid(room_id)? + .ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))? 
+ .to_be_bytes() + .to_vec(); + let mut pdu_id = prefix.clone(); + // +1 so we don't send the base event + let count_raw = match count { + PduCount::Normal(x) => { + if subtract { + x - offset + } else { + x + offset + } + }, + PduCount::Backfilled(x) => { + pdu_id.extend_from_slice(&0_u64.to_be_bytes()); + let num = u64::MAX - x; + if subtract { + if num > 0 { + num - offset + } else { + num + } + } else { + num + offset + } + }, + }; + pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - Ok((prefix, pdu_id)) + Ok((prefix, pdu_id)) } diff --git a/src/database/key_value/rooms/user.rs b/src/database/key_value/rooms/user.rs index f439f481..2298c16d 100644 --- a/src/database/key_value/rooms/user.rs +++ b/src/database/key_value/rooms/user.rs @@ -3,147 +3,122 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}; impl service::rooms::user::Data for KeyValueDatabase { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xff); - roomuser_id.extend_from_slice(user_id.as_bytes()); + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); + let mut roomuser_id = room_id.as_bytes().to_vec(); + roomuser_id.push(0xFF); + roomuser_id.extend_from_slice(user_id.as_bytes()); - self.userroomid_notificationcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; - self.userroomid_highlightcount - .insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?; + self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?; - 
self.roomuserid_lastnotificationread.insert( - &roomuser_id, - &services().globals.next_count()?.to_be_bytes(), - )?; + self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?; - Ok(()) - } + Ok(()) + } - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_notificationcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid notification count in db.")) - }) - .unwrap_or(Ok(0)) - } + self.userroomid_notificationcount + .get(&userroom_id)? + .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db.")) + }) + .unwrap_or(Ok(0)) + } - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut userroom_id = user_id.as_bytes().to_vec(); - userroom_id.push(0xff); - userroom_id.extend_from_slice(room_id.as_bytes()); + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut userroom_id = user_id.as_bytes().to_vec(); + userroom_id.push(0xFF); + userroom_id.extend_from_slice(room_id.as_bytes()); - self.userroomid_highlightcount - .get(&userroom_id)? - .map(|bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid highlight count in db.")) - }) - .unwrap_or(Ok(0)) - } + self.userroomid_highlightcount + .get(&userroom_id)? 
+ .map(|bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db.")) + }) + .unwrap_or(Ok(0)) + } - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(user_id.as_bytes()); + fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id.as_bytes()); - Ok(self - .roomuserid_lastnotificationread - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.") - }) - }) - .transpose()? - .unwrap_or(0)) - } + Ok(self + .roomuserid_lastnotificationread + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")) + }) + .transpose()? + .unwrap_or(0)) + } - fn associate_token_shortstatehash( - &self, - room_id: &RoomId, - token: u64, - shortstatehash: u64, - ) -> Result<()> { - let shortroomid = services() - .rooms - .short - .get_shortroomid(room_id)? - .expect("room exists"); + fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); - self.roomsynctoken_shortstatehash - .insert(&key, &shortstatehash.to_be_bytes()) - } + self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes()) + } - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - let shortroomid = services() - .rooms - .short - .get_shortroomid(room_id)? 
- .expect("room exists"); + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists"); - let mut key = shortroomid.to_be_bytes().to_vec(); - key.extend_from_slice(&token.to_be_bytes()); + let mut key = shortroomid.to_be_bytes().to_vec(); + key.extend_from_slice(&token.to_be_bytes()); - self.roomsynctoken_shortstatehash - .get(&key)? - .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash") - }) - }) - .transpose() - } + self.roomsynctoken_shortstatehash + .get(&key)? + .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")) + }) + .transpose() + } - fn get_shared_rooms<'a>( - &'a self, - users: Vec, - ) -> Result> + 'a>> { - let iterators = users.into_iter().map(move |user_id| { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); + fn get_shared_rooms<'a>( + &'a self, users: Vec, + ) -> Result> + 'a>> { + let iterators = users.into_iter().map(move |user_id| { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); - self.userroomid_joined - .scan_prefix(prefix) - .map(|(key, _)| { - let roomid_index = key - .iter() - .enumerate() - .find(|(_, &b)| b == 0xff) - .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? - .0 - + 1; // +1 because the room id starts AFTER the separator + self.userroomid_joined + .scan_prefix(prefix) + .map(|(key, _)| { + let roomid_index = key + .iter() + .enumerate() + .find(|(_, &b)| b == 0xFF) + .ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))? 
+ .0 + 1; // +1 because the room id starts AFTER the separator - let room_id = key[roomid_index..].to_vec(); + let room_id = key[roomid_index..].to_vec(); - Ok::<_, Error>(room_id) - }) - .filter_map(std::result::Result::ok) - }); + Ok::<_, Error>(room_id) + }) + .filter_map(std::result::Result::ok) + }); - // We use the default compare function because keys are sorted correctly (not reversed) - Ok(Box::new( - utils::common_elements(iterators, Ord::cmp) - .expect("users is not empty") - .map(|bytes| { - RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Invalid RoomId bytes in userroomid_joined") - })?) - .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) - }), - )) - } + // We use the default compare function because keys are sorted correctly (not + // reversed) + Ok(Box::new( + utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| { + RoomId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?, + ) + .map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined.")) + }), + )) + } } diff --git a/src/database/key_value/sending.rs b/src/database/key_value/sending.rs index 6c8e939b..a3ede405 100644 --- a/src/database/key_value/sending.rs +++ b/src/database/key_value/sending.rs @@ -1,205 +1,181 @@ use ruma::{ServerName, UserId}; use crate::{ - database::KeyValueDatabase, - service::{ - self, - sending::{OutgoingKind, SendingEventType}, - }, - services, utils, Error, Result, + database::KeyValueDatabase, + service::{ + self, + sending::{OutgoingKind, SendingEventType}, + }, + services, utils, Error, Result, }; impl service::sending::Data for KeyValueDatabase { - fn active_requests<'a>( - &'a self, - ) -> Box, OutgoingKind, SendingEventType)>> + 'a> { - Box::new( - self.servercurrentevent_data - .iter() - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), - ) - } + fn 
active_requests<'a>( + &'a self, + ) -> Box, OutgoingKind, SendingEventType)>> + 'a> { + Box::new( + self.servercurrentevent_data + .iter() + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(k, e)| (key, k, e))), + ) + } - fn active_requests_for<'a>( - &'a self, - outgoing_kind: &OutgoingKind, - ) -> Box, SendingEventType)>> + 'a> { - let prefix = outgoing_kind.get_prefix(); - Box::new( - self.servercurrentevent_data - .scan_prefix(prefix) - .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), - ) - } + fn active_requests_for<'a>( + &'a self, outgoing_kind: &OutgoingKind, + ) -> Box, SendingEventType)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + Box::new( + self.servercurrentevent_data + .scan_prefix(prefix) + .map(|(key, v)| parse_servercurrentevent(&key, v).map(|(_, e)| (key, e))), + ) + } - fn delete_active_request(&self, key: Vec) -> Result<()> { - self.servercurrentevent_data.remove(&key) - } + fn delete_active_request(&self, key: Vec) -> Result<()> { self.servercurrentevent_data.remove(&key) } - fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { - let prefix = outgoing_kind.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { - self.servercurrentevent_data.remove(&key)?; - } + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in self.servercurrentevent_data.scan_prefix(prefix) { + self.servercurrentevent_data.remove(&key)?; + } - Ok(()) - } + Ok(()) + } - fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { - let prefix = outgoing_kind.get_prefix(); - for (key, _) in self.servercurrentevent_data.scan_prefix(prefix.clone()) { - self.servercurrentevent_data.remove(&key).unwrap(); - } + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> { + let prefix = outgoing_kind.get_prefix(); + for (key, _) in 
self.servercurrentevent_data.scan_prefix(prefix.clone()) { + self.servercurrentevent_data.remove(&key).unwrap(); + } - for (key, _) in self.servernameevent_data.scan_prefix(prefix) { - self.servernameevent_data.remove(&key).unwrap(); - } + for (key, _) in self.servernameevent_data.scan_prefix(prefix) { + self.servernameevent_data.remove(&key).unwrap(); + } - Ok(()) - } + Ok(()) + } - fn queue_requests( - &self, - requests: &[(&OutgoingKind, SendingEventType)], - ) -> Result>> { - let mut batch = Vec::new(); - let mut keys = Vec::new(); - for (outgoing_kind, event) in requests { - let mut key = outgoing_kind.get_prefix(); - if let SendingEventType::Pdu(value) = &event { - key.extend_from_slice(value); - } else { - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - } - let value = if let SendingEventType::Edu(value) = &event { - &**value - } else { - &[] - }; - batch.push((key.clone(), value.to_owned())); - keys.push(key); - } - self.servernameevent_data - .insert_batch(&mut batch.into_iter())?; - Ok(keys) - } + fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result>> { + let mut batch = Vec::new(); + let mut keys = Vec::new(); + for (outgoing_kind, event) in requests { + let mut key = outgoing_kind.get_prefix(); + if let SendingEventType::Pdu(value) = &event { + key.extend_from_slice(value); + } else { + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + } + let value = if let SendingEventType::Edu(value) = &event { + &**value + } else { + &[] + }; + batch.push((key.clone(), value.to_owned())); + keys.push(key); + } + self.servernameevent_data.insert_batch(&mut batch.into_iter())?; + Ok(keys) + } - fn queued_requests<'a>( - &'a self, - outgoing_kind: &OutgoingKind, - ) -> Box)>> + 'a> { - let prefix = outgoing_kind.get_prefix(); - return Box::new( - self.servernameevent_data - .scan_prefix(prefix) - .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), - ); - } + fn 
queued_requests<'a>( + &'a self, outgoing_kind: &OutgoingKind, + ) -> Box)>> + 'a> { + let prefix = outgoing_kind.get_prefix(); + return Box::new( + self.servernameevent_data + .scan_prefix(prefix) + .map(|(k, v)| parse_servercurrentevent(&k, v).map(|(_, ev)| (ev, k))), + ); + } - fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()> { - for (e, key) in events { - let value = if let SendingEventType::Edu(value) = &e { - &**value - } else { - &[] - }; - self.servercurrentevent_data.insert(key, value)?; - self.servernameevent_data.remove(key)?; - } + fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()> { + for (e, key) in events { + let value = if let SendingEventType::Edu(value) = &e { + &**value + } else { + &[] + }; + self.servercurrentevent_data.insert(key, value)?; + self.servernameevent_data.remove(key)?; + } - Ok(()) - } + Ok(()) + } - fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { - self.servername_educount - .insert(server_name.as_bytes(), &last_count.to_be_bytes()) - } + fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> { + self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes()) + } - fn get_latest_educount(&self, server_name: &ServerName) -> Result { - self.servername_educount - .get(server_name.as_bytes())? 
- .map_or(Ok(0), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) - }) - } + fn get_latest_educount(&self, server_name: &ServerName) -> Result { + self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount.")) + }) + } } #[tracing::instrument(skip(key))] -fn parse_servercurrentevent( - key: &[u8], - value: Vec, -) -> Result<(OutgoingKind, SendingEventType)> { - // Appservices start with a plus - Ok::<_, Error>(if key.starts_with(b"+") { - let mut parts = key[1..].splitn(2, |&b| b == 0xff); +fn parse_servercurrentevent(key: &[u8], value: Vec) -> Result<(OutgoingKind, SendingEventType)> { + // Appservices start with a plus + Ok::<_, Error>(if key.starts_with(b"+") { + let mut parts = key[1..].splitn(2, |&b| b == 0xFF); - let server = parts.next().expect("splitn always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let server = parts.next().expect("splitn always returns one element"); + let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; + let server = utils::string_from_bytes(server) + .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; - ( - OutgoingKind::Appservice(server), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - } else if key.starts_with(b"$") { - let mut parts = key[1..].splitn(3, |&b| b == 0xff); + ( + OutgoingKind::Appservice(server), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + } else if key.starts_with(b"$") { + let 
mut parts = key[1..].splitn(3, |&b| b == 0xFF); - let user = parts.next().expect("splitn always returns one element"); - let user_string = utils::string_from_bytes(user) - .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; - let user_id = UserId::parse(user_string) - .map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; + let user = parts.next().expect("splitn always returns one element"); + let user_string = utils::string_from_bytes(user) + .map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?; + let user_id = + UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?; - let pushkey = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let pushkey_string = utils::string_from_bytes(pushkey) - .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; + let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let pushkey_string = utils::string_from_bytes(pushkey) + .map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?; - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - ( - OutgoingKind::Push(user_id, pushkey_string), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - // I'm pretty sure this should never be called - SendingEventType::Edu(value) - }, - ) - } else { - let mut parts = key.splitn(2, |&b| b == 0xff); + ( + OutgoingKind::Push(user_id, pushkey_string), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + // I'm pretty sure this should never be called + SendingEventType::Edu(value) + }, + ) + } else { + let mut parts = key.splitn(2, |&b| b == 0xFF); - let server = parts.next().expect("splitn 
always returns one element"); - let event = parts - .next() - .ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; + let server = parts.next().expect("splitn always returns one element"); + let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?; - let server = utils::string_from_bytes(server).map_err(|_| { - Error::bad_database("Invalid server bytes in server_currenttransaction") - })?; + let server = utils::string_from_bytes(server) + .map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?; - ( - OutgoingKind::Normal(ServerName::parse(server).map_err(|_| { - Error::bad_database("Invalid server string in server_currenttransaction") - })?), - if value.is_empty() { - SendingEventType::Pdu(event.to_vec()) - } else { - SendingEventType::Edu(value) - }, - ) - }) + ( + OutgoingKind::Normal( + ServerName::parse(server) + .map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?, + ), + if value.is_empty() { + SendingEventType::Pdu(event.to_vec()) + } else { + SendingEventType::Edu(value) + }, + ) + }) } diff --git a/src/database/key_value/transaction_ids.rs b/src/database/key_value/transaction_ids.rs index b3bd05f4..f88ae69f 100644 --- a/src/database/key_value/transaction_ids.rs +++ b/src/database/key_value/transaction_ids.rs @@ -3,37 +3,30 @@ use ruma::{DeviceId, TransactionId, UserId}; use crate::{database::KeyValueDatabase, service, Result}; impl service::transaction_ids::Data for KeyValueDatabase { - fn add_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - data: &[u8], - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xff); - key.extend_from_slice(txn_id.as_bytes()); + fn add_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: 
&[u8], + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - self.userdevicetxnid_response.insert(&key, data)?; + self.userdevicetxnid_response.insert(&key, data)?; - Ok(()) - } + Ok(()) + } - fn existing_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); - key.push(0xff); - key.extend_from_slice(txn_id.as_bytes()); + fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, + ) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(txn_id.as_bytes()); - // If there's no entry, this is a new transaction - self.userdevicetxnid_response.get(&key) - } + // If there's no entry, this is a new transaction + self.userdevicetxnid_response.get(&key) + } } diff --git a/src/database/key_value/uiaa.rs b/src/database/key_value/uiaa.rs index 652c12d0..de200614 100644 --- a/src/database/key_value/uiaa.rs +++ b/src/database/key_value/uiaa.rs @@ -1,89 +1,64 @@ use ruma::{ - api::client::{error::ErrorKind, uiaa::UiaaInfo}, - CanonicalJsonValue, DeviceId, UserId, + api::client::{error::ErrorKind, uiaa::UiaaInfo}, + CanonicalJsonValue, DeviceId, UserId, }; use crate::{database::KeyValueDatabase, service, Error, Result}; impl service::uiaa::Data for KeyValueDatabase { - fn set_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - request: &CanonicalJsonValue, - ) -> Result<()> { - self.userdevicesessionid_uiaarequest - .write() - .unwrap() - .insert( - (user_id.to_owned(), device_id.to_owned(), 
session.to_owned()), - request.to_owned(), - ); + fn set_uiaa_request( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, + ) -> Result<()> { + self.userdevicesessionid_uiaarequest.write().unwrap().insert( + (user_id.to_owned(), device_id.to_owned(), session.to_owned()), + request.to_owned(), + ); - Ok(()) - } + Ok(()) + } - fn get_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Option { - self.userdevicesessionid_uiaarequest - .read() - .unwrap() - .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) - .map(std::borrow::ToOwned::to_owned) - } + fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option { + self.userdevicesessionid_uiaarequest + .read() + .unwrap() + .get(&(user_id.to_owned(), device_id.to_owned(), session.to_owned())) + .map(std::borrow::ToOwned::to_owned) + } - fn update_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - uiaainfo: Option<&UiaaInfo>, - ) -> Result<()> { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); + fn update_uiaa_session( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, + ) -> Result<()> { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); - if let Some(uiaainfo) = uiaainfo { - self.userdevicesessionid_uiaainfo.insert( - &userdevicesessionid, - &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), - )?; - } else { - self.userdevicesessionid_uiaainfo - .remove(&userdevicesessionid)?; - } + if let 
Some(uiaainfo) = uiaainfo { + self.userdevicesessionid_uiaainfo.insert( + &userdevicesessionid, + &serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"), + )?; + } else { + self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?; + } - Ok(()) - } + Ok(()) + } - fn get_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Result { - let mut userdevicesessionid = user_id.as_bytes().to_vec(); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(device_id.as_bytes()); - userdevicesessionid.push(0xff); - userdevicesessionid.extend_from_slice(session.as_bytes()); + fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result { + let mut userdevicesessionid = user_id.as_bytes().to_vec(); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(device_id.as_bytes()); + userdevicesessionid.push(0xFF); + userdevicesessionid.extend_from_slice(session.as_bytes()); - serde_json::from_slice( - &self - .userdevicesessionid_uiaainfo - .get(&userdevicesessionid)? - .ok_or(Error::BadRequest( - ErrorKind::Forbidden, - "UIAA session does not exist.", - ))?, - ) - .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) - } + serde_json::from_slice( + &self + .userdevicesessionid_uiaainfo + .get(&userdevicesessionid)? 
+ .ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?, + ) + .map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid.")) + } } diff --git a/src/database/key_value/users.rs b/src/database/key_value/users.rs index 5601009a..221f35f1 100644 --- a/src/database/key_value/users.rs +++ b/src/database/key_value/users.rs @@ -1,997 +1,835 @@ use std::{collections::BTreeMap, mem::size_of}; use ruma::{ - api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::{AnyToDeviceEvent, StateEventType}, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, - OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId, + api::client::{device::Device, error::ErrorKind, filter::FilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::{AnyToDeviceEvent, StateEventType}, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId, + OwnedMxcUri, OwnedUserId, UInt, UserId, }; use tracing::warn; use crate::{ - database::KeyValueDatabase, - service::{self, users::clean_signatures}, - services, utils, Error, Result, + database::KeyValueDatabase, + service::{self, users::clean_signatures}, + services, utils, Error, Result, }; impl service::users::Data for KeyValueDatabase { - /// Check if a user has an account on this homeserver. - fn exists(&self, user_id: &UserId) -> Result { - Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) - } - - /// Check if account is deactivated - fn is_deactivated(&self, user_id: &UserId) -> Result { - Ok(self - .userid_password - .get(user_id.as_bytes())? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not exist.", - ))? - .is_empty()) - } - - /// Returns the number of users registered on this server. 
- fn count(&self) -> Result { - Ok(self.userid_password.iter().count()) - } - - /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result> { - self.token_userdeviceid - .get(token.as_bytes())? - .map_or(Ok(None), |bytes| { - let mut parts = bytes.split(|&b| b == 0xff); - let user_bytes = parts.next().ok_or_else(|| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?; - let device_bytes = parts.next().ok_or_else(|| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?; - - Ok(Some(( - UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid unicode.") - })?) - .map_err(|_| { - Error::bad_database("User ID in token_userdeviceid is invalid.") - })?, - utils::string_from_bytes(device_bytes).map_err(|_| { - Error::bad_database("Device ID in token_userdeviceid is invalid.") - })?, - ))) - }) - } - - /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a> { - Box::new(self.userid_password.iter().map(|(bytes, _)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("User ID in userid_password is invalid unicode.") - })?) - .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) - })) - } - - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is greater then zero. - fn list_local_users(&self) -> Result> { - let users: Vec = self - .userid_password - .iter() - .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) - .collect(); - Ok(users) - } - - /// Returns the password hash for the given user. - fn password_hash(&self, user_id: &UserId) -> Result> { - self.userid_password - .get(user_id.as_bytes())? 
- .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Password hash in db is not valid string.") - })?)) - }) - } - - /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - if let Some(password) = password { - if let Ok(hash) = utils::calculate_password_hash(password) { - self.userid_password - .insert(user_id.as_bytes(), hash.as_bytes())?; - Ok(()) - } else { - Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Password does not meet the requirements.", - )) - } - } else { - self.userid_password.insert(user_id.as_bytes(), b"")?; - Ok(()) - } - } - - /// Returns the displayname of a user on this homeserver. - fn displayname(&self, user_id: &UserId) -> Result> { - self.userid_displayname - .get(user_id.as_bytes())? - .map_or(Ok(None), |bytes| { - Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Displayname in db is invalid.") - })?)) - }) - } - - /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { - if let Some(displayname) = displayname { - self.userid_displayname - .insert(user_id.as_bytes(), displayname.as_bytes())?; - } else { - self.userid_displayname.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the avatar_url of a user. - fn avatar_url(&self, user_id: &UserId) -> Result> { - self.userid_avatarurl - .get(user_id.as_bytes())? - .map(|bytes| { - let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| { - warn!("Avatar URL in db is invalid: {}", e); - Error::bad_database("Avatar URL in db is invalid.") - })?; - let mxc_uri: OwnedMxcUri = s_bytes.into(); - Ok(mxc_uri) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. 
- fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { - if let Some(avatar_url) = avatar_url { - self.userid_avatarurl - .insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; - } else { - self.userid_avatarurl.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Get the blurhash of a user. - fn blurhash(&self, user_id: &UserId) -> Result> { - self.userid_blurhash - .get(user_id.as_bytes())? - .map(|bytes| { - let s = utils::string_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; - - Ok(s) - }) - .transpose() - } - - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - if let Some(blurhash) = blurhash { - self.userid_blurhash - .insert(user_id.as_bytes(), blurhash.as_bytes())?; - } else { - self.userid_blurhash.remove(user_id.as_bytes())?; - } - - Ok(()) - } - - /// Adds a new device to a user. - fn create_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - token: &str, - initial_device_display_name: Option, - ) -> Result<()> { - // This method should never be called for nonexistent users. We shouldn't assert though... - if !self.exists(user_id)? 
{ - warn!( - "Called create_device for non-existent user {} in database", - user_id - ); - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User does not exist.", - )); - } - - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(&Device { - device_id: device_id.into(), - display_name: initial_device_display_name, - last_seen_ip: None, // TODO - last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), - }) - .expect("Device::to_string never fails."), - )?; - - self.set_token(user_id, device_id, token)?; - - Ok(()) - } - - /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Remove tokens - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.userdeviceid_token.remove(&userdeviceid)?; - self.token_userdeviceid.remove(&old_token)?; - } - - // Remove todevice events - let mut prefix = userdeviceid.clone(); - prefix.push(0xff); - - for (key, _) in self.todeviceid_events.scan_prefix(prefix) { - self.todeviceid_events.remove(&key)?; - } - - // TODO: Remove onetimekeys - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.remove(&userdeviceid)?; - - Ok(()) - } - - /// Returns an iterator over all device ids of this user. 
- fn all_device_ids<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - // All devices have metadata - Box::new( - self.userdeviceid_metadata - .scan_prefix(prefix) - .map(|(bytes, _)| { - Ok(utils::string_from_bytes( - bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("UserDevice ID in db is invalid.") - })?, - ) - .map_err(|_| { - Error::bad_database("Device ID in userdeviceid_metadata is invalid.") - })? - .into()) - }), - ) - } - - /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // should not be None, but we shouldn't assert either lol... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!("Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - // Remove old token - if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { - self.token_userdeviceid.remove(&old_token)?; - // It will be removed from userdeviceid_token by the insert later - } - - // Assign token to user device combination - self.userdeviceid_token - .insert(&userdeviceid, token.as_bytes())?; - self.token_userdeviceid - .insert(token.as_bytes(), &userdeviceid)?; - - Ok(()) - } - - fn add_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - - // All devices have metadata - // Only existing devices should be able to call this, but we shouldn't assert either... 
- if self.userdeviceid_metadata.get(&key)?.is_none() { - warn!("Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - key.push(0xff); - // TODO: Use DeviceKeyId::to_string when it's available (and update everything, - // because there are no wrapping quotation marks anymore) - key.extend_from_slice( - serde_json::to_string(one_time_key_key) - .expect("DeviceKeyId::to_string always works") - .as_bytes(), - ); - - self.onetimekeyid_onetimekeys.insert( - &key, - &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), - )?; - - self.userid_lastonetimekeyupdate.insert( - user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), - )?; - - Ok(()) - } - - fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.userid_lastonetimekeyupdate - .get(user_id.as_bytes())? 
- .map(|bytes| { - utils::u64_from_bytes(&bytes).map_err(|_| { - Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.") - }) - }) - .unwrap_or(Ok(0)) - } - - fn take_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - prefix.push(b'"'); // Annoying quotation mark - prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); - prefix.push(b':'); - - self.userid_lastonetimekeyupdate.insert( - user_id.as_bytes(), - &services().globals.next_count()?.to_be_bytes(), - )?; - - self.onetimekeyid_onetimekeys - .scan_prefix(prefix) - .next() - .map(|(key, value)| { - self.onetimekeyid_onetimekeys.remove(&key)?; - - Ok(( - serde_json::from_slice( - key.rsplit(|&b| b == 0xff) - .next() - .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, - ) - .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, - )) - }) - .transpose() - } - - fn count_one_time_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - let mut counts = BTreeMap::new(); - - for algorithm in - self.onetimekeyid_onetimekeys - .scan_prefix(userdeviceid) - .map(|(bytes, _)| { - Ok::<_, Error>( - serde_json::from_slice::( - bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| { - Error::bad_database("OneTimeKey ID in db is invalid.") - })?, - ) - .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? 
- .algorithm(), - ) - }) - { - *counts.entry(algorithm?).or_default() += UInt::from(1_u32); - } - - Ok(counts) - } - - fn add_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - device_keys: &Raw, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.insert( - &userdeviceid, - &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), - )?; - - self.mark_device_key_update(user_id)?; - - Ok(()) - } - - fn add_cross_signing_keys( - &self, - user_id: &UserId, - master_key: &Raw, - self_signing_key: &Option>, - user_signing_key: &Option>, - notify: bool, - ) -> Result<()> { - // TODO: Check signatures - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; - - self.keyid_key - .insert(&master_key_key, master_key.json().get().as_bytes())?; - - self.userid_masterkeyid - .insert(user_id.as_bytes(), &master_key_key)?; - - // Self-signing key - if let Some(self_signing_key) = self_signing_key { - let mut self_signing_key_ids = self_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key") - })? 
- .keys - .into_values(); - - let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained no key.", - ))?; - - if self_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Self signing key contained more than one key.", - )); - } - - let mut self_signing_key_key = prefix.clone(); - self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &self_signing_key_key, - self_signing_key.json().get().as_bytes(), - )?; - - self.userid_selfsigningkeyid - .insert(user_id.as_bytes(), &self_signing_key_key)?; - } - - // User-signing key - if let Some(user_signing_key) = user_signing_key { - let mut user_signing_key_ids = user_signing_key - .deserialize() - .map_err(|_| { - Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key") - })? - .keys - .into_values(); - - let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained no key.", - ))?; - - if user_signing_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "User signing key contained more than one key.", - )); - } - - let mut user_signing_key_key = prefix; - user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); - - self.keyid_key.insert( - &user_signing_key_key, - user_signing_key.json().get().as_bytes(), - )?; - - self.userid_usersigningkeyid - .insert(user_id.as_bytes(), &user_signing_key_key)?; - } - - if notify { - self.mark_device_key_update(user_id)?; - } - - Ok(()) - } - - fn sign_key( - &self, - target_id: &UserId, - key_id: &str, - signature: (String, String), - sender_id: &UserId, - ) -> Result<()> { - let mut key = target_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(key_id.as_bytes()); - - let mut cross_signing_key: serde_json::Value = - 
serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Tried to sign nonexistent key.", - ))?) - .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; - - let signatures = cross_signing_key - .get_mut("signatures") - .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? - .as_object_mut() - .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? - .entry(sender_id.to_string()) - .or_insert_with(|| serde_json::Map::new().into()); - - signatures - .as_object_mut() - .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? - .insert(signature.0, signature.1.into()); - - self.keyid_key.insert( - &key, - &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), - )?; - - self.mark_device_key_update(target_id)?; - - Ok(()) - } - - fn keys_changed<'a>( - &'a self, - user_or_room_id: &str, - from: u64, - to: Option, - ) -> Box> + 'a> { - let mut prefix = user_or_room_id.as_bytes().to_vec(); - prefix.push(0xff); - - let mut start = prefix.clone(); - start.extend_from_slice(&(from + 1).to_be_bytes()); - - let to = to.unwrap_or(u64::MAX); - - Box::new( - self.keychangeid_userid - .iter_from(&start, false) - .take_while(move |(k, _)| { - k.starts_with(&prefix) - && if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) { - if let Ok(c) = utils::u64_from_bytes(current) { - c <= to - } else { - warn!("BadDatabase: Could not parse keychangeid_userid bytes"); - false - } - } else { - warn!("BadDatabase: Could not parse keychangeid_userid"); - false - } - }) - .map(|(_, bytes)| { - UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| { - Error::bad_database( - "User ID in devicekeychangeid_userid is invalid unicode.", - ) - })?) 
- .map_err(|_| { - Error::bad_database("User ID in devicekeychangeid_userid is invalid.") - }) - }), - ) - } - - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - let count = services().globals.next_count()?.to_be_bytes(); - for room_id in services() - .rooms - .state_cache - .rooms_joined(user_id) - .filter_map(std::result::Result::ok) - { - // Don't send key updates to unencrypted rooms - if services() - .rooms - .state_accessor - .room_state_get(&room_id, &StateEventType::RoomEncryption, "")? - .is_none() - { - continue; - } - - let mut key = room_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - } - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&count); - self.keychangeid_userid.insert(&key, user_id.as_bytes())?; - - Ok(()) - } - - fn get_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(device_id.as_bytes()); - - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("DeviceKeys in db are invalid.") - })?)) - }) - } - - fn parse_master_key( - &self, - user_id: &UserId, - master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - - let master_key = master_key - .deserialize() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; - let mut master_key_ids = master_key.keys.values(); - let master_key_id = master_key_ids.next().ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained no key.", - ))?; - if master_key_ids.next().is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Master key contained more than one key.", - )); - } - let mut master_key_key = prefix.clone(); - 
master_key_key.extend_from_slice(master_key_id.as_bytes()); - Ok((master_key_key, master_key)) - } - - fn get_key( - &self, - key: &[u8], - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { - let mut cross_signing_key = serde_json::from_slice::(&bytes) - .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; - clean_signatures( - &mut cross_signing_key, - sender_user, - user_id, - allowed_signatures, - )?; - - Ok(Some(Raw::from_json( - serde_json::value::to_raw_value(&cross_signing_key) - .expect("Value to RawValue serialization"), - ))) - }) - } - - fn get_master_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_masterkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.get_key(&key, sender_user, user_id, allowed_signatures) - }) - } - - fn get_self_signing_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.userid_selfsigningkeyid - .get(user_id.as_bytes())? - .map_or(Ok(None), |key| { - self.get_key(&key, sender_user, user_id, allowed_signatures) - }) - } - - fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.userid_usersigningkeyid - .get(user_id.as_bytes())? 
- .map_or(Ok(None), |key| { - self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("CrossSigningKey in db is invalid.") - })?)) - }) - }) - } - - fn add_to_device_event( - &self, - sender: &UserId, - target_user_id: &UserId, - target_device_id: &DeviceId, - event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - let mut key = target_user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(target_device_id.as_bytes()); - key.push(0xff); - key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); - - let mut json = serde_json::Map::new(); - json.insert("type".to_owned(), event_type.to_owned().into()); - json.insert("sender".to_owned(), sender.to_string().into()); - json.insert("content".to_owned(), content); - - let value = serde_json::to_vec(&json).expect("Map::to_vec always works"); - - self.todeviceid_events.insert(&key, &value)?; - - Ok(()) - } - - fn get_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>> { - let mut events = Vec::new(); - - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - - for (_, value) in self.todeviceid_events.scan_prefix(prefix) { - events.push( - serde_json::from_slice(&value) - .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, - ); - } - - Ok(events) - } - - fn remove_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - until: u64, - ) -> Result<()> { - let mut prefix = user_id.as_bytes().to_vec(); - prefix.push(0xff); - prefix.extend_from_slice(device_id.as_bytes()); - prefix.push(0xff); - - let mut last = prefix.clone(); - last.extend_from_slice(&until.to_be_bytes()); - - for (key, _) in self - .todeviceid_events - .iter_from(&last, true) // this includes last - .take_while(move |(k, _)| k.starts_with(&prefix)) - .map(|(key, _)| { - Ok::<_, 
Error>(( - key.clone(), - utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) - .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, - )) - }) - .filter_map(std::result::Result::ok) - .take_while(|&(_, count)| count <= until) - { - self.todeviceid_events.remove(&key)?; - } - - Ok(()) - } - - fn update_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - device: &Device, - ) -> Result<()> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - // Only existing devices should be able to call this, but we shouldn't assert either... - if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { - warn!("Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id); - return Err(Error::bad_database( - "User does not exist or device ID has no metadata in database.", - )); - } - - self.userid_devicelistversion - .increment(user_id.as_bytes())?; - - self.userdeviceid_metadata.insert( - &userdeviceid, - &serde_json::to_vec(device).expect("Device::to_string always works"), - )?; - - Ok(()) - } - - /// Get device metadata. - fn get_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - let mut userdeviceid = user_id.as_bytes().to_vec(); - userdeviceid.push(0xff); - userdeviceid.extend_from_slice(device_id.as_bytes()); - - self.userdeviceid_metadata - .get(&userdeviceid)? - .map_or(Ok(None), |bytes| { - Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { - Error::bad_database("Metadata in userdeviceid_metadata is invalid.") - })?)) - }) - } - - fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.userid_devicelistversion - .get(user_id.as_bytes())? 
- .map_or(Ok(None), |bytes| { - utils::u64_from_bytes(&bytes) - .map_err(|_| Error::bad_database("Invalid devicelistversion in db.")) - .map(Some) - }) - } - - fn all_devices_metadata<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - - Box::new( - self.userdeviceid_metadata - .scan_prefix(key) - .map(|(_, bytes)| { - serde_json::from_slice::(&bytes).map_err(|_| { - Error::bad_database("Device in userdeviceid_metadata is invalid.") - }) - }), - ) - } - - /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - let filter_id = utils::random_string(4); - - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - self.userfilterid_filter.insert( - &key, - &serde_json::to_vec(&filter).expect("filter is valid json"), - )?; - - Ok(filter_id) - } - - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { - let mut key = user_id.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(filter_id.as_bytes()); - - let raw = self.userfilterid_filter.get(&key)?; - - if let Some(raw) = raw { - serde_json::from_slice(&raw) - .map_err(|_| Error::bad_database("Invalid filter event in db.")) - } else { - Ok(None) - } - } + /// Check if a user has an account on this homeserver. + fn exists(&self, user_id: &UserId) -> Result { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) } + + /// Check if account is deactivated + fn is_deactivated(&self, user_id: &UserId) -> Result { + Ok(self + .userid_password + .get(user_id.as_bytes())? + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))? + .is_empty()) + } + + /// Returns the number of users registered on this server. + fn count(&self) -> Result { Ok(self.userid_password.iter().count()) } + + /// Find out which user an access token belongs to. 
+ fn find_from_token(&self, token: &str) -> Result> { + self.token_userdeviceid.get(token.as_bytes())?.map_or(Ok(None), |bytes| { + let mut parts = bytes.split(|&b| b == 0xFF); + let user_bytes = + parts.next().ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?; + let device_bytes = + parts.next().ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?; + + Ok(Some(( + UserId::parse( + utils::string_from_bytes(user_bytes) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?, + utils::string_from_bytes(device_bytes) + .map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?, + ))) + }) + } + + /// Returns an iterator over all users on this homeserver. + fn iter<'a>(&'a self) -> Box> + 'a> { + Box::new(self.userid_password.iter().map(|(bytes, _)| { + UserId::parse( + utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?, + ) + .map_err(|_| Error::bad_database("User ID in userid_password is invalid.")) + })) + } + + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is + /// greater then zero. + fn list_local_users(&self) -> Result> { + let users: Vec = self + .userid_password + .iter() + .filter_map(|(username, pw)| get_username_with_valid_password(&username, &pw)) + .collect(); + Ok(users) + } + + /// Returns the password hash for the given user. 
+ fn password_hash(&self, user_id: &UserId) -> Result> { + self.userid_password.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { + Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("Password hash in db is not valid string.") + })?)) + }) + } + + /// Hash and set the user's password to the Argon2 hash + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + if let Some(password) = password { + if let Ok(hash) = utils::calculate_password_hash(password) { + self.userid_password.insert(user_id.as_bytes(), hash.as_bytes())?; + Ok(()) + } else { + Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Password does not meet the requirements.", + )) + } + } else { + self.userid_password.insert(user_id.as_bytes(), b"")?; + Ok(()) + } + } + + /// Returns the displayname of a user on this homeserver. + fn displayname(&self, user_id: &UserId) -> Result> { + self.userid_displayname.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { + Ok(Some( + utils::string_from_bytes(&bytes).map_err(|_| Error::bad_database("Displayname in db is invalid."))?, + )) + }) + } + + /// Sets a new displayname or removes it if displayname is None. You still + /// need to nofify all rooms of this change. + fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + if let Some(displayname) = displayname { + self.userid_displayname.insert(user_id.as_bytes(), displayname.as_bytes())?; + } else { + self.userid_displayname.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the avatar_url of a user. + fn avatar_url(&self, user_id: &UserId) -> Result> { + self.userid_avatarurl + .get(user_id.as_bytes())? 
+ .map(|bytes| { + let s_bytes = utils::string_from_bytes(&bytes).map_err(|e| { + warn!("Avatar URL in db is invalid: {}", e); + Error::bad_database("Avatar URL in db is invalid.") + })?; + let mxc_uri: OwnedMxcUri = s_bytes.into(); + Ok(mxc_uri) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + if let Some(avatar_url) = avatar_url { + self.userid_avatarurl.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?; + } else { + self.userid_avatarurl.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Get the blurhash of a user. + fn blurhash(&self, user_id: &UserId) -> Result> { + self.userid_blurhash + .get(user_id.as_bytes())? + .map(|bytes| { + let s = utils::string_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Avatar URL in db is invalid."))?; + + Ok(s) + }) + .transpose() + } + + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + if let Some(blurhash) = blurhash { + self.userid_blurhash.insert(user_id.as_bytes(), blurhash.as_bytes())?; + } else { + self.userid_blurhash.remove(user_id.as_bytes())?; + } + + Ok(()) + } + + /// Adds a new device to a user. + fn create_device( + &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, + ) -> Result<()> { + // This method should never be called for nonexistent users. We shouldn't assert + // though... + if !self.exists(user_id)? 
{ + warn!("Called create_device for non-existent user {} in database", user_id); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist.")); + } + + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userid_devicelistversion.increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(&Device { + device_id: device_id.into(), + display_name: initial_device_display_name, + last_seen_ip: None, // TODO + last_seen_ts: Some(MilliSecondsSinceUnixEpoch::now()), + }) + .expect("Device::to_string never fails."), + )?; + + self.set_token(user_id, device_id, token)?; + + Ok(()) + } + + /// Removes a device from a user. + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Remove tokens + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.userdeviceid_token.remove(&userdeviceid)?; + self.token_userdeviceid.remove(&old_token)?; + } + + // Remove todevice events + let mut prefix = userdeviceid.clone(); + prefix.push(0xFF); + + for (key, _) in self.todeviceid_events.scan_prefix(prefix) { + self.todeviceid_events.remove(&key)?; + } + + // TODO: Remove onetimekeys + + self.userid_devicelistversion.increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.remove(&userdeviceid)?; + + Ok(()) + } + + /// Returns an iterator over all device ids of this user. 
+ fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + // All devices have metadata + Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(|(bytes, _)| { + Ok(utils::string_from_bytes( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))? + .into()) + })) + } + + /// Replaces the access token of one device. + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // should not be None, but we shouldn't assert either lol... + if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { + warn!( + "Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + // Remove old token + if let Some(old_token) = self.userdeviceid_token.get(&userdeviceid)? { + self.token_userdeviceid.remove(&old_token)?; + // It will be removed from userdeviceid_token by the insert later + } + + // Assign token to user device combination + self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?; + self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?; + + Ok(()) + } + + fn add_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + ) -> Result<()> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + + // All devices have metadata + // Only existing devices should be able to call this, but we shouldn't assert + // either... 
+ if self.userdeviceid_metadata.get(&key)?.is_none() { + warn!( + "Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \ + database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + key.push(0xFF); + // TODO: Use DeviceKeyId::to_string when it's available (and update everything, + // because there are no wrapping quotation marks anymore) + key.extend_from_slice( + serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works").as_bytes(), + ); + + self.onetimekeyid_onetimekeys.insert( + &key, + &serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"), + )?; + + self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + + Ok(()) + } + + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.userid_lastonetimekeyupdate + .get(user_id.as_bytes())? 
+ .map(|bytes| { + utils::u64_from_bytes(&bytes) + .map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")) + }) + .unwrap_or(Ok(0)) + } + + fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result)>> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + prefix.push(b'"'); // Annoying quotation mark + prefix.extend_from_slice(key_algorithm.as_ref().as_bytes()); + prefix.push(b':'); + + self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?; + + self.onetimekeyid_onetimekeys + .scan_prefix(prefix) + .next() + .map(|(key, value)| { + self.onetimekeyid_onetimekeys.remove(&key)?; + + Ok(( + serde_json::from_slice( + key.rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("OneTimeKeyId in db is invalid."))?, + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("OneTimeKeys in db are invalid."))?, + )) + }) + .transpose() + } + + fn count_one_time_keys( + &self, user_id: &UserId, device_id: &DeviceId, + ) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + let mut counts = BTreeMap::new(); + + for algorithm in self.onetimekeyid_onetimekeys.scan_prefix(userdeviceid).map(|(bytes, _)| { + Ok::<_, Error>( + serde_json::from_slice::( + bytes + .rsplit(|&b| b == 0xFF) + .next() + .ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?, + ) + .map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))? 
+ .algorithm(), + ) + }) { + *counts.entry(algorithm?).or_default() += UInt::from(1_u32); + } + + Ok(counts) + } + + fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.insert( + &userdeviceid, + &serde_json::to_vec(&device_keys).expect("DeviceKeys::to_vec always works"), + )?; + + self.mark_device_key_update(user_id)?; + + Ok(()) + } + + fn add_cross_signing_keys( + &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, + user_signing_key: &Option>, notify: bool, + ) -> Result<()> { + // TODO: Check signatures + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let (master_key_key, _) = self.parse_master_key(user_id, master_key)?; + + self.keyid_key.insert(&master_key_key, master_key.json().get().as_bytes())?; + + self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?; + + // Self-signing key + if let Some(self_signing_key) = self_signing_key { + let mut self_signing_key_ids = self_signing_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))? 
+ .keys + .into_values(); + + let self_signing_key_id = self_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?; + + if self_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Self signing key contained more than one key.", + )); + } + + let mut self_signing_key_key = prefix.clone(); + self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes()); + + self.keyid_key.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?; + + self.userid_selfsigningkeyid.insert(user_id.as_bytes(), &self_signing_key_key)?; + } + + // User-signing key + if let Some(user_signing_key) = user_signing_key { + let mut user_signing_key_ids = user_signing_key + .deserialize() + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))? + .keys + .into_values(); + + let user_signing_key_id = user_signing_key_ids + .next() + .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?; + + if user_signing_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "User signing key contained more than one key.", + )); + } + + let mut user_signing_key_key = prefix; + user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes()); + + self.keyid_key.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?; + + self.userid_usersigningkeyid.insert(user_id.as_bytes(), &user_signing_key_key)?; + } + + if notify { + self.mark_device_key_update(user_id)?; + } + + Ok(()) + } + + fn sign_key( + &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, + ) -> Result<()> { + let mut key = target_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(key_id.as_bytes()); + + let mut cross_signing_key: serde_json::Value = serde_json::from_slice( + &self + .keyid_key + .get(&key)? 
+ .ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?, + ) + .map_err(|_| Error::bad_database("key in keyid_key is invalid."))?; + + let signatures = cross_signing_key + .get_mut("signatures") + .ok_or_else(|| Error::bad_database("key in keyid_key has no signatures field."))? + .as_object_mut() + .ok_or_else(|| Error::bad_database("key in keyid_key has invalid signatures field."))? + .entry(sender_id.to_string()) + .or_insert_with(|| serde_json::Map::new().into()); + + signatures + .as_object_mut() + .ok_or_else(|| Error::bad_database("signatures in keyid_key for a user is invalid."))? + .insert(signature.0, signature.1.into()); + + self.keyid_key.insert( + &key, + &serde_json::to_vec(&cross_signing_key).expect("CrossSigningKey::to_vec always works"), + )?; + + self.mark_device_key_update(target_id)?; + + Ok(()) + } + + fn keys_changed<'a>( + &'a self, user_or_room_id: &str, from: u64, to: Option, + ) -> Box> + 'a> { + let mut prefix = user_or_room_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let mut start = prefix.clone(); + start.extend_from_slice(&(from + 1).to_be_bytes()); + + let to = to.unwrap_or(u64::MAX); + + Box::new( + self.keychangeid_userid + .iter_from(&start, false) + .take_while(move |(k, _)| { + k.starts_with(&prefix) + && if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) { + if let Ok(c) = utils::u64_from_bytes(current) { + c <= to + } else { + warn!("BadDatabase: Could not parse keychangeid_userid bytes"); + false + } + } else { + warn!("BadDatabase: Could not parse keychangeid_userid"); + false + } + }) + .map(|(_, bytes)| { + UserId::parse( + utils::string_from_bytes(&bytes).map_err(|_| { + Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid.")) + }), + ) + } + + fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { + let count = 
services().globals.next_count()?.to_be_bytes(); + for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) { + // Don't send key updates to unencrypted rooms + if services().rooms.state_accessor.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?.is_none() + { + continue; + } + + let mut key = room_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + } + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&count); + self.keychangeid_userid.insert(&key, user_id.as_bytes())?; + + Ok(()) + } + + fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(device_id.as_bytes()); + + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?, + )) + }) + } + + fn parse_master_key( + &self, user_id: &UserId, master_key: &Raw, + ) -> Result<(Vec, CrossSigningKey)> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + + let master_key = + master_key.deserialize().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?; + let mut master_key_ids = master_key.keys.values(); + let master_key_id = + master_key_ids.next().ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?; + if master_key_ids.next().is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Master key contained more than one key.", + )); + } + let mut master_key_key = prefix.clone(); + master_key_key.extend_from_slice(master_key_id.as_bytes()); + Ok((master_key_key, master_key)) + } + + fn get_key( + &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + 
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| { + let mut cross_signing_key = serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?; + clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?; + + Ok(Some(Raw::from_json( + serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"), + ))) + }) + } + + fn get_master_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.userid_masterkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) + } + + fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.userid_selfsigningkeyid + .get(user_id.as_bytes())? + .map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures)) + } + + fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(Ok(None), |key| { + self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| { + Ok(Some( + serde_json::from_slice(&bytes) + .map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?, + )) + }) + }) + } + + fn add_to_device_event( + &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, + content: serde_json::Value, + ) -> Result<()> { + let mut key = target_user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(target_device_id.as_bytes()); + key.push(0xFF); + key.extend_from_slice(&services().globals.next_count()?.to_be_bytes()); + + let mut json = serde_json::Map::new(); + json.insert("type".to_owned(), event_type.to_owned().into()); + json.insert("sender".to_owned(), sender.to_string().into()); + json.insert("content".to_owned(), content); + + let value = 
serde_json::to_vec(&json).expect("Map::to_vec always works"); + + self.todeviceid_events.insert(&key, &value)?; + + Ok(()) + } + + fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + let mut events = Vec::new(); + + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + for (_, value) in self.todeviceid_events.scan_prefix(prefix) { + events.push( + serde_json::from_slice(&value) + .map_err(|_| Error::bad_database("Event in todeviceid_events is invalid."))?, + ); + } + + Ok(events) + } + + fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { + let mut prefix = user_id.as_bytes().to_vec(); + prefix.push(0xFF); + prefix.extend_from_slice(device_id.as_bytes()); + prefix.push(0xFF); + + let mut last = prefix.clone(); + last.extend_from_slice(&until.to_be_bytes()); + + for (key, _) in self + .todeviceid_events + .iter_from(&last, true) // this includes last + .take_while(move |(k, _)| k.starts_with(&prefix)) + .map(|(key, _)| { + Ok::<_, Error>(( + key.clone(), + utils::u64_from_bytes(&key[key.len() - size_of::()..key.len()]) + .map_err(|_| Error::bad_database("ToDeviceId has invalid count bytes."))?, + )) + }) + .filter_map(std::result::Result::ok) + .take_while(|&(_, count)| count <= until) + { + self.todeviceid_events.remove(&key)?; + } + + Ok(()) + } + + fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + // Only existing devices should be able to call this, but we shouldn't assert + // either... 
+ if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() { + warn!( + "Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \ + metadata in database", + user_id, device_id + ); + return Err(Error::bad_database( + "User does not exist or device ID has no metadata in database.", + )); + } + + self.userid_devicelistversion.increment(user_id.as_bytes())?; + + self.userdeviceid_metadata.insert( + &userdeviceid, + &serde_json::to_vec(device).expect("Device::to_string always works"), + )?; + + Ok(()) + } + + /// Get device metadata. + fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + let mut userdeviceid = user_id.as_bytes().to_vec(); + userdeviceid.push(0xFF); + userdeviceid.extend_from_slice(device_id.as_bytes()); + + self.userdeviceid_metadata.get(&userdeviceid)?.map_or(Ok(None), |bytes| { + Ok(Some(serde_json::from_slice(&bytes).map_err(|_| { + Error::bad_database("Metadata in userdeviceid_metadata is invalid.") + })?)) + }) + } + + fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| { + utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid devicelistversion in db.")).map(Some) + }) + } + + fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + + Box::new(self.userdeviceid_metadata.scan_prefix(key).map(|(_, bytes)| { + serde_json::from_slice::(&bytes) + .map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid.")) + })) + } + + /// Creates a new sync filter. Returns the filter id. 
+ fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + let filter_id = utils::random_string(4); + + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); + + self.userfilterid_filter.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?; + + Ok(filter_id) + } + + fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + let mut key = user_id.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(filter_id.as_bytes()); + + let raw = self.userfilterid_filter.get(&key)?; + + if let Some(raw) = raw { + serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db.")) + } else { + Ok(None) + } + } } impl KeyValueDatabase {} /// Will only return with Some(username) if the password was not empty and the /// username could be successfully parsed. -/// If utils::string_from_bytes(...) returns an error that username will be skipped -/// and the error will be logged. +/// If utils::string_from_bytes(...) returns an error that username will be +/// skipped and the error will be logged. 
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option { - // A valid password is not empty - if password.is_empty() { - None - } else { - match utils::string_from_bytes(username) { - Ok(u) => Some(u), - Err(e) => { - warn!( - "Failed to parse username while calling get_local_users(): {}", - e.to_string() - ); - None - } - } - } + // A valid password is not empty + if password.is_empty() { + None + } else { + match utils::string_from_bytes(username) { + Ok(u) => Some(u), + Err(e) => { + warn!("Failed to parse username while calling get_local_users(): {}", e.to_string()); + None + }, + } + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index e9926695..687aa95d 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -1,1217 +1,1150 @@ pub(crate) mod abstraction; pub(crate) mod key_value; -use crate::{ - service::rooms::{edus::presence::presence_handler, timeline::PduCount}, - services, utils, Config, Error, PduEvent, Result, Services, SERVICES, +use std::{ + collections::{BTreeMap, HashMap, HashSet}, + fs::{self}, + io::Write, + mem::size_of, + path::Path, + sync::{Arc, Mutex, RwLock}, + time::Duration, }; + use abstraction::{KeyValueDatabaseEngine, KvTree}; use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier}; use itertools::Itertools; use lru_cache::LruCache; use rand::thread_rng; use ruma::{ - api::appservice::Registration, - events::{ - push_rules::{PushRulesEvent, PushRulesEventContent}, - room::message::RoomMessageEventContent, - GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, - }, - push::Ruleset, - CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, - UserId, + api::appservice::Registration, + events::{ + push_rules::{PushRulesEvent, PushRulesEventContent}, + room::message::RoomMessageEventContent, + GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType, + }, + push::Ruleset, + CanonicalJsonValue, EventId, OwnedDeviceId, 
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId, }; use serde::Deserialize; -use std::{ - collections::{BTreeMap, HashMap, HashSet}, - fs::{self}, - io::Write, - mem::size_of, - path::Path, - sync::{Arc, Mutex, RwLock}, - time::Duration, -}; use tokio::{sync::mpsc, time::interval}; - use tracing::{debug, error, info, warn}; +use crate::{ + service::rooms::{edus::presence::presence_handler, timeline::PduCount}, + services, utils, Config, Error, PduEvent, Result, Services, SERVICES, +}; + pub struct KeyValueDatabase { - db: Arc, + db: Arc, - //pub globals: globals::Globals, - pub(super) global: Arc, - pub(super) server_signingkeys: Arc, + //pub globals: globals::Globals, + pub(super) global: Arc, + pub(super) server_signingkeys: Arc, - //pub users: users::Users, - pub(super) userid_password: Arc, - pub(super) userid_displayname: Arc, - pub(super) userid_avatarurl: Arc, - pub(super) userid_blurhash: Arc, - pub(super) userdeviceid_token: Arc, - pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists - pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 - pub(super) token_userdeviceid: Arc, + //pub users: users::Users, + pub(super) userid_password: Arc, + pub(super) userid_displayname: Arc, + pub(super) userid_avatarurl: Arc, + pub(super) userid_blurhash: Arc, + pub(super) userdeviceid_token: Arc, + pub(super) userdeviceid_metadata: Arc, // This is also used to check if a device exists + pub(super) userid_devicelistversion: Arc, // DevicelistVersion = u64 + pub(super) token_userdeviceid: Arc, - pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId - pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count - pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count - pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) - pub(super) userid_masterkeyid: Arc, - pub(super) userid_selfsigningkeyid: Arc, - pub(super) 
userid_usersigningkeyid: Arc, + pub(super) onetimekeyid_onetimekeys: Arc, // OneTimeKeyId = UserId + DeviceKeyId + pub(super) userid_lastonetimekeyupdate: Arc, // LastOneTimeKeyUpdate = Count + pub(super) keychangeid_userid: Arc, // KeyChangeId = UserId/RoomId + Count + pub(super) keyid_key: Arc, // KeyId = UserId + KeyId (depends on key type) + pub(super) userid_masterkeyid: Arc, + pub(super) userid_selfsigningkeyid: Arc, + pub(super) userid_usersigningkeyid: Arc, - pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId + pub(super) userfilterid_filter: Arc, // UserFilterId = UserId + FilterId - pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count + pub(super) todeviceid_events: Arc, // ToDeviceId = UserId + DeviceId + Count - //pub uiaa: uiaa::Uiaa, - pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication - pub(super) userdevicesessionid_uiaarequest: - RwLock>, + //pub uiaa: uiaa::Uiaa, + pub(super) userdevicesessionid_uiaainfo: Arc, // User-interactive authentication + pub(super) userdevicesessionid_uiaarequest: + RwLock>, - //pub edus: RoomEdus, - pub(super) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId - pub(super) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count - pub(super) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count - pub(super) typingid_userid: Arc, // TypingId = RoomId + TimeoutTime + Count - pub(super) roomid_lasttypingupdate: Arc, // LastRoomTypingUpdate = Count - pub(super) roomuserid_presence: Arc, + //pub edus: RoomEdus, + pub(super) readreceiptid_readreceipt: Arc, // ReadReceiptId = RoomId + Count + UserId + pub(super) roomuserid_privateread: Arc, // RoomUserId = Room + User, PrivateRead = Count + pub(super) roomuserid_lastprivatereadupdate: Arc, // LastPrivateReadUpdate = Count + pub(super) typingid_userid: Arc, // TypingId = RoomId + TimeoutTime + Count + pub(super) roomid_lasttypingupdate: Arc, 
// LastRoomTypingUpdate = Count + pub(super) roomuserid_presence: Arc, - //pub rooms: rooms::Rooms, - pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count - pub(super) eventid_pduid: Arc, - pub(super) roomid_pduleaves: Arc, - pub(super) alias_roomid: Arc, - pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count - pub(super) publicroomids: Arc, + //pub rooms: rooms::Rooms, + pub(super) pduid_pdu: Arc, // PduId = ShortRoomId + Count + pub(super) eventid_pduid: Arc, + pub(super) roomid_pduleaves: Arc, + pub(super) alias_roomid: Arc, + pub(super) aliasid_alias: Arc, // AliasId = RoomId + Count + pub(super) publicroomids: Arc, - pub(super) threadid_userids: Arc, // ThreadId = RoomId + Count + pub(super) threadid_userids: Arc, // ThreadId = RoomId + Count - pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount + pub(super) tokenids: Arc, // TokenId = ShortRoomId + Token + PduIdCount - /// Participating servers in a room. - pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName - pub(super) serverroomids: Arc, // ServerRoomId = ServerName + RoomId + /// Participating servers in a room. 
+ pub(super) roomserverids: Arc, // RoomServerId = RoomId + ServerName + pub(super) serverroomids: Arc, // ServerRoomId = ServerName + RoomId - pub(super) userroomid_joined: Arc, - pub(super) roomuserid_joined: Arc, - pub(super) roomid_joinedcount: Arc, - pub(super) roomid_invitedcount: Arc, - pub(super) roomuseroncejoinedids: Arc, - pub(super) userroomid_invitestate: Arc, // InviteState = Vec> - pub(super) roomuserid_invitecount: Arc, // InviteCount = Count - pub(super) userroomid_leftstate: Arc, - pub(super) roomuserid_leftcount: Arc, + pub(super) userroomid_joined: Arc, + pub(super) roomuserid_joined: Arc, + pub(super) roomid_joinedcount: Arc, + pub(super) roomid_invitedcount: Arc, + pub(super) roomuseroncejoinedids: Arc, + pub(super) userroomid_invitestate: Arc, // InviteState = Vec> + pub(super) roomuserid_invitecount: Arc, // InviteCount = Count + pub(super) userroomid_leftstate: Arc, + pub(super) roomuserid_leftcount: Arc, - pub(super) disabledroomids: Arc, // Rooms where incoming federation handling is disabled + pub(super) disabledroomids: Arc, // Rooms where incoming federation handling is disabled - pub(super) bannedroomids: Arc, // Rooms where local users are not allowed to join + pub(super) bannedroomids: Arc, // Rooms where local users are not allowed to join - pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId + pub(super) lazyloadedids: Arc, // LazyLoadedIds = UserId + DeviceId + RoomId + LazyLoadedUserId - pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 - pub(super) userroomid_highlightcount: Arc, // HightlightCount = u64 - pub(super) roomuserid_lastnotificationread: Arc, // LastNotificationRead = u64 + pub(super) userroomid_notificationcount: Arc, // NotifyCount = u64 + pub(super) userroomid_highlightcount: Arc, // HightlightCount = u64 + pub(super) roomuserid_lastnotificationread: Arc, // LastNotificationRead = u64 - /// Remember the current state hash of a room. 
- pub(super) roomid_shortstatehash: Arc, - pub(super) roomsynctoken_shortstatehash: Arc, - /// Remember the state hash at events in the past. - pub(super) shorteventid_shortstatehash: Arc, - /// StateKey = EventType + StateKey, ShortStateKey = Count - pub(super) statekey_shortstatekey: Arc, - pub(super) shortstatekey_statekey: Arc, + /// Remember the current state hash of a room. + pub(super) roomid_shortstatehash: Arc, + pub(super) roomsynctoken_shortstatehash: Arc, + /// Remember the state hash at events in the past. + pub(super) shorteventid_shortstatehash: Arc, + /// StateKey = EventType + StateKey, ShortStateKey = Count + pub(super) statekey_shortstatekey: Arc, + pub(super) shortstatekey_statekey: Arc, - pub(super) roomid_shortroomid: Arc, + pub(super) roomid_shortroomid: Arc, - pub(super) shorteventid_eventid: Arc, - pub(super) eventid_shorteventid: Arc, + pub(super) shorteventid_eventid: Arc, + pub(super) eventid_shorteventid: Arc, - pub(super) statehash_shortstatehash: Arc, - pub(super) shortstatehash_statediff: Arc, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--) + pub(super) statehash_shortstatehash: Arc, + pub(super) shortstatehash_statediff: Arc, /* StateDiff = parent (or 0) + + * (shortstatekey+shorteventid++) + 0_u64 + + * (shortstatekey+shorteventid--) */ - pub(super) shorteventid_authchain: Arc, + pub(super) shorteventid_authchain: Arc, - /// RoomId + EventId -> outlier PDU. - /// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn. - pub(super) eventid_outlierpdu: Arc, - pub(super) softfailedeventids: Arc, + /// RoomId + EventId -> outlier PDU. + /// Any pdu that has passed the steps 1-8 in the incoming event + /// /federation/send/txn. + pub(super) eventid_outlierpdu: Arc, + pub(super) softfailedeventids: Arc, - /// ShortEventId + ShortEventId -> (). - pub(super) tofrom_relation: Arc, - /// RoomId + EventId -> Parent PDU EventId. 
- pub(super) referencedevents: Arc, + /// ShortEventId + ShortEventId -> (). + pub(super) tofrom_relation: Arc, + /// RoomId + EventId -> Parent PDU EventId. + pub(super) referencedevents: Arc, - //pub account_data: account_data::AccountData, - pub(super) roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type - pub(super) roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type + //pub account_data: account_data::AccountData, + pub(super) roomuserdataid_accountdata: Arc, // RoomUserDataId = Room + User + Count + Type + pub(super) roomusertype_roomuserdataid: Arc, // RoomUserType = Room + User + Type - //pub media: media::Media, - pub(super) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType - pub(super) url_previews: Arc, - //pub key_backups: key_backups::KeyBackups, - pub(super) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) - pub(super) backupid_etag: Arc, // BackupId = UserId + Version(Count) - pub(super) backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId + //pub media: media::Media, + pub(super) mediaid_file: Arc, // MediaId = MXC + WidthHeight + ContentDisposition + ContentType + pub(super) url_previews: Arc, + //pub key_backups: key_backups::KeyBackups, + pub(super) backupid_algorithm: Arc, // BackupId = UserId + Version(Count) + pub(super) backupid_etag: Arc, // BackupId = UserId + Version(Count) + pub(super) backupkeyid_backup: Arc, // BackupKeyId = UserId + Version + RoomId + SessionId - //pub transaction_ids: transaction_ids::TransactionIds, - pub(super) userdevicetxnid_response: Arc, // Response can be empty (/sendToDevice) or the event id (/send) - //pub sending: sending::Sending, - pub(super) servername_educount: Arc, // EduCount: Count of last EDU sync - pub(super) servernameevent_data: Arc, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content - pub(super) servercurrentevent_data: Arc, 
// ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content + //pub transaction_ids: transaction_ids::TransactionIds, + pub(super) userdevicetxnid_response: Arc, /* Response can be empty (/sendToDevice) or the event id + * (/send) */ + //pub sending: sending::Sending, + pub(super) servername_educount: Arc, // EduCount: Count of last EDU sync + pub(super) servernameevent_data: Arc, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId + + * PduId / Id (for edus), Data = EDU content */ + pub(super) servercurrentevent_data: Arc, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId + * / Id (for edus), Data = EDU content */ - //pub appservice: appservice::Appservice, - pub(super) id_appserviceregistrations: Arc, + //pub appservice: appservice::Appservice, + pub(super) id_appserviceregistrations: Arc, - //pub pusher: pusher::PushData, - pub(super) senderkey_pusher: Arc, + //pub pusher: pusher::PushData, + pub(super) senderkey_pusher: Arc, - pub(super) cached_registrations: Arc>>, - pub(super) pdu_cache: Mutex>>, - pub(super) shorteventid_cache: Mutex>>, - pub(super) auth_chain_cache: Mutex, Arc>>>, - pub(super) eventidshort_cache: Mutex>, - pub(super) statekeyshort_cache: Mutex>, - pub(super) shortstatekey_cache: Mutex>, - pub(super) our_real_users_cache: RwLock>>>, - pub(super) appservice_in_room_cache: RwLock>>, - pub(super) lasttimelinecount_cache: Mutex>, - pub(super) presence_timer_sender: Arc>, + pub(super) cached_registrations: Arc>>, + pub(super) pdu_cache: Mutex>>, + pub(super) shorteventid_cache: Mutex>>, + pub(super) auth_chain_cache: Mutex, Arc>>>, + pub(super) eventidshort_cache: Mutex>, + pub(super) statekeyshort_cache: Mutex>, + pub(super) shortstatekey_cache: Mutex>, + pub(super) our_real_users_cache: RwLock>>>, + pub(super) appservice_in_room_cache: RwLock>>, + pub(super) lasttimelinecount_cache: Mutex>, + pub(super) presence_timer_sender: Arc>, } impl KeyValueDatabase { - fn check_db_setup(config: 
&Config) -> Result<()> { - let path = Path::new(&config.database_path); - - let sqlite_exists = path.join("conduit.db").exists(); - let rocksdb_exists = path.join("IDENTITY").exists(); - - let mut count = 0; - - if sqlite_exists { - count += 1; - } - - if rocksdb_exists { - count += 1; - } - - if count > 1 { - warn!("Multiple databases at database_path detected"); - return Ok(()); - } - - if sqlite_exists && config.database_backend != "sqlite" { - return Err(Error::bad_config( - "Found sqlite at database_path, but is not specified in config.", - )); - } - - if rocksdb_exists && config.database_backend != "rocksdb" { - return Err(Error::bad_config( - "Found rocksdb at database_path, but is not specified in config.", - )); - } - - Ok(()) - } - - /// Load an existing database or create a new one. - pub async fn load_or_create(config: Config) -> Result<()> { - Self::check_db_setup(&config)?; - - if !Path::new(&config.database_path).exists() { - debug!("Database path does not exist, assuming this is a new setup and creating it"); - std::fs::create_dir_all(&config.database_path) - .map_err(|e| { - error!("Failed to create database path: {e}"); - Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself or allow conduwuit the permissions to create directories and files.")})?; - } - - let builder: Arc = match &*config.database_backend { - "sqlite" => { - debug!("Got sqlite database backend"); - #[cfg(not(feature = "sqlite"))] - return Err(Error::BadConfig("Database backend not found.")); - #[cfg(feature = "sqlite")] - Arc::new(Arc::::open(&config)?) - } - "rocksdb" => { - debug!("Got rocksdb database backend"); - #[cfg(not(feature = "rocksdb"))] - return Err(Error::BadConfig("Database backend not found.")); - #[cfg(feature = "rocksdb")] - Arc::new(Arc::::open(&config)?) - } - _ => { - return Err(Error::BadConfig("Database backend not found. 
sqlite (not recommended) and rocksdb are the only supported backends.")); - } - }; - - let (presence_sender, presence_receiver) = mpsc::unbounded_channel(); - - let db_raw = Box::new(Self { - db: builder.clone(), - userid_password: builder.open_tree("userid_password")?, - userid_displayname: builder.open_tree("userid_displayname")?, - userid_avatarurl: builder.open_tree("userid_avatarurl")?, - userid_blurhash: builder.open_tree("userid_blurhash")?, - userdeviceid_token: builder.open_tree("userdeviceid_token")?, - userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, - userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, - token_userdeviceid: builder.open_tree("token_userdeviceid")?, - onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, - userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, - keychangeid_userid: builder.open_tree("keychangeid_userid")?, - keyid_key: builder.open_tree("keyid_key")?, - userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, - userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, - userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, - userfilterid_filter: builder.open_tree("userfilterid_filter")?, - todeviceid_events: builder.open_tree("todeviceid_events")?, - - userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, - userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), - readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, - roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt - roomuserid_lastprivatereadupdate: builder - .open_tree("roomuserid_lastprivatereadupdate")?, - typingid_userid: builder.open_tree("typingid_userid")?, - roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, - roomuserid_presence: builder.open_tree("roomuserid_presence")?, - pduid_pdu: 
builder.open_tree("pduid_pdu")?, - eventid_pduid: builder.open_tree("eventid_pduid")?, - roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, - - alias_roomid: builder.open_tree("alias_roomid")?, - aliasid_alias: builder.open_tree("aliasid_alias")?, - publicroomids: builder.open_tree("publicroomids")?, - - threadid_userids: builder.open_tree("threadid_userids")?, - - tokenids: builder.open_tree("tokenids")?, - - roomserverids: builder.open_tree("roomserverids")?, - serverroomids: builder.open_tree("serverroomids")?, - userroomid_joined: builder.open_tree("userroomid_joined")?, - roomuserid_joined: builder.open_tree("roomuserid_joined")?, - roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, - roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, - roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, - userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, - roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, - userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, - roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, - - disabledroomids: builder.open_tree("disabledroomids")?, - - bannedroomids: builder.open_tree("bannedroomids")?, - - lazyloadedids: builder.open_tree("lazyloadedids")?, - - userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, - userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, - roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, - - statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, - shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, - - shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, - - roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, - - shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, - eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, - 
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, - shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, - roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, - roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, - statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, - - eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, - softfailedeventids: builder.open_tree("softfailedeventids")?, - - tofrom_relation: builder.open_tree("tofrom_relation")?, - referencedevents: builder.open_tree("referencedevents")?, - roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, - roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, - mediaid_file: builder.open_tree("mediaid_file")?, - url_previews: builder.open_tree("url_previews")?, - backupid_algorithm: builder.open_tree("backupid_algorithm")?, - backupid_etag: builder.open_tree("backupid_etag")?, - backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, - userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, - servername_educount: builder.open_tree("servername_educount")?, - servernameevent_data: builder.open_tree("servernameevent_data")?, - servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, - id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, - senderkey_pusher: builder.open_tree("senderkey_pusher")?, - global: builder.open_tree("global")?, - server_signingkeys: builder.open_tree("server_signingkeys")?, - - cached_registrations: Arc::new(RwLock::new(HashMap::new())), - pdu_cache: Mutex::new(LruCache::new( - config - .pdu_cache_capacity - .try_into() - .expect("pdu cache capacity fits into usize"), - )), - auth_chain_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shorteventid_cache: Mutex::new(LruCache::new( - (100_000.0 * 
config.conduit_cache_capacity_modifier) as usize, - )), - eventidshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - shortstatekey_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - statekeyshort_cache: Mutex::new(LruCache::new( - (100_000.0 * config.conduit_cache_capacity_modifier) as usize, - )), - our_real_users_cache: RwLock::new(HashMap::new()), - appservice_in_room_cache: RwLock::new(HashMap::new()), - lasttimelinecount_cache: Mutex::new(HashMap::new()), - presence_timer_sender: Arc::new(presence_sender), - }); - - let db = Box::leak(db_raw); - - let services_raw = Box::new(Services::build(db, config)?); - - // This is the first and only time we initialize the SERVICE static - *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); - - // Matrix resource ownership is based on the server name; changing it - // requires recreating the database from scratch. - if services().users.count()? > 0 { - let conduit_user = - UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); - - if !services().users.exists(&conduit_user)? { - error!( - "The {} server user does not exist, and the database is not new.", - conduit_user - ); - return Err(Error::bad_database( - "Cannot reuse an existing database after changing the server name, please delete the old one first." - )); - } - } - - // If the database has any data, perform data migrations before starting - // do not increment the db version if the user is not using sha256_media - let latest_database_version = if cfg!(feature = "sha256_media") { - 14 - } else { - 13 - }; - - if services().users.count()? > 0 { - // MIGRATIONS - if services().globals.database_version()? 
< 1 { - for (roomserverid, _) in db.roomserverids.iter() { - let mut parts = roomserverid.split(|&b| b == 0xff); - let room_id = parts.next().expect("split always returns one element"); - let servername = match parts.next() { - Some(s) => s, - None => { - error!("Migration: Invalid roomserverid in db."); - continue; - } - }; - let mut serverroomid = servername.to_vec(); - serverroomid.push(0xff); - serverroomid.extend_from_slice(room_id); - - db.serverroomids.insert(&serverroomid, &[])?; - } - - services().globals.bump_database_version(1)?; - - warn!("Migration: 0 -> 1 finished"); - } - - if services().globals.database_version()? < 2 { - // We accidentally inserted hashed versions of "" into the db instead of just "" - for (userid, password) in db.userid_password.iter() { - let salt = SaltString::generate(thread_rng()); - let empty_pass = services() - .globals - .argon - .hash_password(b"", &salt) - .expect("our own password to be properly hashed"); - let empty_hashed_password = services() - .globals - .argon - .verify_password(&password, &empty_pass) - .is_ok(); - - if empty_hashed_password { - db.userid_password.insert(&userid, b"")?; - } - } - - services().globals.bump_database_version(2)?; - - warn!("Migration: 1 -> 2 finished"); - } - - if services().globals.database_version()? < 3 { - // Move media to filesystem - for (key, content) in db.mediaid_file.iter() { - if content.is_empty() { - continue; - } - - #[allow(deprecated)] - let path = services().globals.get_media_file(&key); - let mut file = fs::File::create(path)?; - file.write_all(&content)?; - db.mediaid_file.insert(&key, &[])?; - } - - services().globals.bump_database_version(3)?; - - warn!("Migration: 2 -> 3 finished"); - } - - if services().globals.database_version()? < 4 { - // Add federated users to services() as deactivated - for our_user in services().users.iter() { - let our_user = our_user?; - if services().users.is_deactivated(&our_user)? 
{ - continue; - } - for room in services().rooms.state_cache.rooms_joined(&our_user) { - for user in services().rooms.state_cache.room_members(&room?) { - let user = user?; - if user.server_name() != services().globals.server_name() { - info!(?user, "Migration: creating user"); - services().users.create(&user, None)?; - } - } - } - } - - services().globals.bump_database_version(4)?; - - warn!("Migration: 3 -> 4 finished"); - } - - if services().globals.database_version()? < 5 { - // Upgrade user data store - for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { - let mut parts = roomuserdataid.split(|&b| b == 0xff); - let room_id = parts.next().unwrap(); - let user_id = parts.next().unwrap(); - let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap(); - - let mut key = room_id.to_vec(); - key.push(0xff); - key.extend_from_slice(user_id); - key.push(0xff); - key.extend_from_slice(event_type); - - db.roomusertype_roomuserdataid - .insert(&key, &roomuserdataid)?; - } - - services().globals.bump_database_version(5)?; - - warn!("Migration: 4 -> 5 finished"); - } - - if services().globals.database_version()? < 6 { - // Set room member count - for (roomid, _) in db.roomid_shortstatehash.iter() { - let string = utils::string_from_bytes(&roomid).unwrap(); - let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); - services().rooms.state_cache.update_joined_count(room_id)?; - } - - services().globals.bump_database_version(6)?; - - warn!("Migration: 5 -> 6 finished"); - } - - if services().globals.database_version()? 
< 7 { - // Upgrade state store - let mut last_roomstates: HashMap = HashMap::new(); - let mut current_sstatehash: Option = None; - let mut current_room = None; - let mut current_state = HashSet::new(); - let mut counter = 0; - - let mut handle_state = - |current_sstatehash: u64, - current_room: &RoomId, - current_state: HashSet<_>, - last_roomstates: &mut HashMap<_, _>| { - counter += 1; - let last_roomsstatehash = last_roomstates.get(current_room); - - let states_parents = last_roomsstatehash.map_or_else( - || Ok(Vec::new()), - |&last_roomsstatehash| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(last_roomsstatehash) - }, - )?; - - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew = current_state - .difference(&parent_stateinfo.1) - .copied() - .collect::>(); - - let statediffremoved = parent_stateinfo - .1 - .difference(¤t_state) - .copied() - .collect::>(); - - (statediffnew, statediffremoved) - } else { - (current_state, HashSet::new()) - }; - - services().rooms.state_compressor.save_state_from_diff( - current_sstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, // every state change is 2 event changes on average - states_parents, - )?; - - /* - let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; - let state = tmp.pop().unwrap(); - println!( - "{}\t{}{:?}: {:?} + {:?} - {:?}", - current_room, - " ".repeat(tmp.len()), - utils::u64_from_bytes(¤t_sstatehash).unwrap(), - tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), - state - .2 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>(), - state - .3 - .iter() - .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) - .collect::>() - ); - */ - - Ok::<_, Error>(()) - }; - - for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() { - let sstatehash = utils::u64_from_bytes(&k[0..size_of::()]) - .expect("number of bytes is 
correct"); - let sstatekey = k[size_of::()..].to_vec(); - if Some(sstatehash) != current_sstatehash { - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - last_roomstates - .insert(current_room.clone().unwrap(), current_sstatehash); - } - current_state = HashSet::new(); - current_sstatehash = Some(sstatehash); - - let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); - let string = utils::string_from_bytes(&event_id).unwrap(); - let event_id = <&EventId>::try_from(string.as_str()).unwrap(); - let pdu = services() - .rooms - .timeline - .get_pdu(event_id) - .unwrap() - .unwrap(); - - if Some(&pdu.room_id) != current_room.as_ref() { - current_room = Some(pdu.room_id.clone()); - } - } - - let mut val = sstatekey; - val.extend_from_slice(&seventid); - current_state.insert(val.try_into().expect("size is correct")); - } - - if let Some(current_sstatehash) = current_sstatehash { - handle_state( - current_sstatehash, - current_room.as_deref().unwrap(), - current_state, - &mut last_roomstates, - )?; - } - - services().globals.bump_database_version(7)?; - - warn!("Migration: 6 -> 7 finished"); - } - - if services().globals.database_version()? 
< 8 { - // Generate short room ids for all rooms - for (room_id, _) in db.roomid_shortstatehash.iter() { - let shortroomid = services().globals.next_count()?.to_be_bytes(); - db.roomid_shortroomid.insert(&room_id, &shortroomid)?; - info!("Migration: 8"); - } - // Update pduids db layout - let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(2, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_key = short_room_id; - new_key.extend_from_slice(count); - - Some((new_key, v)) - }); - - db.pduid_pdu.insert_batch(&mut batch)?; - - let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { - if !value.starts_with(b"!") { - return None; - } - let mut parts = value.splitn(2, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - - let mut new_value = short_room_id; - new_value.extend_from_slice(count); - - Some((k, new_value)) - }); - - db.eventid_pduid.insert_batch(&mut batch2)?; - - services().globals.bump_database_version(8)?; - - warn!("Migration: 7 -> 8 finished"); - } - - if services().globals.database_version()? 
< 9 { - // Update tokenids db layout - let mut iter = db - .tokenids - .iter() - .filter_map(|(key, _)| { - if !key.starts_with(b"!") { - return None; - } - let mut parts = key.splitn(4, |&b| b == 0xff); - let room_id = parts.next().unwrap(); - let word = parts.next().unwrap(); - let _pdu_id_room = parts.next().unwrap(); - let pdu_id_count = parts.next().unwrap(); - - let short_room_id = db - .roomid_shortroomid - .get(room_id) - .unwrap() - .expect("shortroomid should exist"); - let mut new_key = short_room_id; - new_key.extend_from_slice(word); - new_key.push(0xff); - new_key.extend_from_slice(pdu_id_count); - Some((new_key, Vec::new())) - }) - .peekable(); - - while iter.peek().is_some() { - db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; - debug!("Inserted smaller batch"); - } - - info!("Deleting starts"); - - let batch2: Vec<_> = db - .tokenids - .iter() - .filter_map(|(key, _)| { - if key.starts_with(b"!") { - Some(key) - } else { - None - } - }) - .collect(); - - for key in batch2 { - db.tokenids.remove(&key)?; - } - - services().globals.bump_database_version(9)?; - - warn!("Migration: 8 -> 9 finished"); - } - - if services().globals.database_version()? < 10 { - // Add other direction for shortstatekeys - for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { - db.shortstatekey_statekey - .insert(&shortstatekey, &statekey)?; - } - - // Force E2EE device list updates so we can send them over federation - for user_id in services().users.iter().filter_map(std::result::Result::ok) { - services().users.mark_device_key_update(&user_id)?; - } - - services().globals.bump_database_version(10)?; - - warn!("Migration: 9 -> 10 finished"); - } - - if services().globals.database_version()? < 11 { - db.db - .open_tree("userdevicesessionid_uiaarequest")? - .clear()?; - services().globals.bump_database_version(11)?; - - warn!("Migration: 10 -> 11 finished"); - } - - if services().globals.database_version()? 
< 12 { - for username in services().users.list_local_users()? { - let user = match UserId::parse_with_server_name( - username.clone(), - services().globals.server_name(), - ) { - Ok(u) => u, - Err(e) => { - warn!("Invalid username {username}: {e}"); - continue; - } - }; - - let raw_rules_list = services() - .account_data - .get( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - ) - .unwrap() - .expect("Username is invalid"); - - let mut account_data = - serde_json::from_str::(raw_rules_list.get()).unwrap(); - let rules_list = &mut account_data.content.global; - - //content rule - { - let content_rule_transformation = - [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; - - let rule = rules_list.content.get(content_rule_transformation[0]); - if rule.is_some() { - let mut rule = rule.unwrap().clone(); - rule.rule_id = content_rule_transformation[1].to_owned(); - rules_list - .content - .shift_remove(content_rule_transformation[0]); - rules_list.content.insert(rule); - } - } - - //underride rules - { - let underride_rule_transformation = [ - [".m.rules.call", ".m.rule.call"], - [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], - [ - ".m.rules.encrypted_room_one_to_one", - ".m.rule.encrypted_room_one_to_one", - ], - [".m.rules.message", ".m.rule.message"], - [".m.rules.encrypted", ".m.rule.encrypted"], - ]; - - for transformation in underride_rule_transformation { - let rule = rules_list.underride.get(transformation[0]); - if let Some(rule) = rule { - let mut rule = rule.clone(); - rule.rule_id = transformation[1].to_owned(); - rules_list.underride.shift_remove(transformation[0]); - rules_list.underride.insert(rule); - } - } - } - - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; - } - - services().globals.bump_database_version(12)?; - - warn!("Migration: 11 -> 12 finished"); - } - 
- // This migration can be reused as-is anytime the server-default rules are updated. - if services().globals.database_version()? < 13 { - for username in services().users.list_local_users()? { - let user = match UserId::parse_with_server_name( - username.clone(), - services().globals.server_name(), - ) { - Ok(u) => u, - Err(e) => { - warn!("Invalid username {username}: {e}"); - continue; - } - }; - - let raw_rules_list = services() - .account_data - .get( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - ) - .unwrap() - .expect("Username is invalid"); - - let mut account_data = - serde_json::from_str::(raw_rules_list.get()).unwrap(); - - let user_default_rules = ruma::push::Ruleset::server_default(&user); - account_data - .content - .global - .update_with_server_default(user_default_rules); - - services().account_data.update( - None, - &user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(account_data).expect("to json value always works"), - )?; - } - - services().globals.bump_database_version(13)?; - - warn!("Migration: 12 -> 13 finished"); - } - - if services().globals.database_version()? 
< 14 && cfg!(feature = "sha256_media") { - warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names"); - // Move old media files to new names - for (key, _) in db.mediaid_file.iter() { - // we know that this method is deprecated, but we need to use it to migrate the old files - // to the new location - // - // TODO: remove this once we're sure that all users have migrated - #[allow(deprecated)] - let old_path = services().globals.get_media_file(&key); - let path = services().globals.get_media_file_new(&key); - // move the file to the new location - if old_path.exists() { - tokio::fs::rename(&old_path, &path).await?; - } - } - - services().globals.bump_database_version(14)?; - - warn!("Migration: 13 -> 14 finished"); - } - - assert_eq!( - services().globals.database_version().unwrap(), - latest_database_version, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", services().globals.database_version().unwrap(), latest_database_version - ); - - { - let patterns = &services().globals.config.forbidden_usernames; - if !patterns.is_empty() { - for user in services().users.iter() { - let user_id = user?; - let matches = patterns.matches(user_id.localpart()); - if matches.matched_any() { - warn!( - "User {} matches the following forbidden username patterns: {}", - user_id.to_string(), - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - - { - let patterns = &services().globals.config.forbidden_room_names; - if !patterns.is_empty() { - for address in services().rooms.metadata.iter_ids() { - let room_id = address?; - let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id); - for room_alias_result in room_aliases { - let room_alias = room_alias_result?; - let matches = patterns.matches(room_alias.alias()); - if matches.matched_any() { - warn!( - "Room with alias {} ({}) matches the following forbidden room name patterns: 
{}", - room_alias, - &room_id, - matches - .into_iter() - .map(|x| &patterns.patterns()[x]) - .join(", ") - ); - } - } - } - } - } - - info!( - "Loaded {} database with version {}", - services().globals.config.database_backend, - latest_database_version - ); - } else { - services() - .globals - .bump_database_version(latest_database_version)?; - - // Create the admin room and server user on first run - services().admin.create_admin_room().await?; - - warn!( - "Created new {} database with version {}", - services().globals.config.database_backend, - latest_database_version - ); - } - - services().admin.start_handler(); - - // Set emergency access for the conduit user - match set_emergency_access() { - Ok(pwd_set) => { - if pwd_set { - warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"); - services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!")); - } - } - Err(e) => { - error!( - "Could not set the configured emergency password for the conduit user: {}", - e - ); - } - }; - - services().sending.start_handler(); - - Self::start_cleanup_task().await; - if services().globals.allow_check_for_updates() { - Self::start_check_for_updates_task(); - } - if services().globals.allow_local_presence() { - Self::start_presence_handler(presence_receiver).await; - } - - Ok(()) - } - - pub fn flush(&self) -> Result<()> { - let start = std::time::Instant::now(); - - let res = self.db.flush(); - - debug!("flush: took {:?}", start.elapsed()); - - res - } - - #[tracing::instrument] - pub fn start_check_for_updates_task() { - tokio::spawn(async move { - let timer_interval = Duration::from_secs(60 * 60); - let mut i = interval(timer_interval); - loop { - i.tick().await; - let _ = Self::try_handle_updates().await; - } - }); - } - - async fn try_handle_updates() -> Result<()> { - let response = services() - 
.globals - .default_client() - .get("https://pupbrain.dev/check-for-updates/stable") - .send() - .await?; - - #[derive(Deserialize)] - struct CheckForUpdatesResponseEntry { - id: u64, - date: String, - message: String, - } - #[derive(Deserialize)] - struct CheckForUpdatesResponse { - updates: Vec, - } - - let response = serde_json::from_str::(&response.text().await?) - .map_err(|e| { - error!("Bad check for updates response: {e}"); - Error::BadServerResponse("Bad version check response") - })?; - - let mut last_update_id = services().globals.last_check_for_updates_id()?; - for update in response.updates { - last_update_id = last_update_id.max(update.id); - if update.id > services().globals.last_check_for_updates_id()? { - error!("{}", update.message); - services() - .admin - .send_message(RoomMessageEventContent::text_plain(format!( - "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", - update.date, update.message - ))); - } - } - services() - .globals - .update_check_for_updates_id(last_update_id)?; - - Ok(()) - } - - #[tracing::instrument] - pub async fn start_cleanup_task() { - #[cfg(unix)] - use tokio::signal::unix::{signal, SignalKind}; - use tokio::time::Instant; - - let timer_interval = - Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval)); - - fn perform_cleanup() { - let start = Instant::now(); - if let Err(e) = services().globals.cleanup() { - error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e); - } else { - debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed()); - } - } - - tokio::spawn(async move { - let mut i = interval(timer_interval); - #[cfg(unix)] - let mut hangup = signal(SignalKind::hangup()).unwrap(); - let mut ctrl_c = signal(SignalKind::interrupt()).unwrap(); - let mut terminate = signal(SignalKind::terminate()).unwrap(); - - loop { - #[cfg(unix)] - tokio::select! 
{ - _ = i.tick() => { - debug!(target: "database-cleanup", "Timer ticked"); - } - _ = hangup.recv() => { - debug!(target: "database-cleanup","Received SIGHUP"); - } - _ = ctrl_c.recv() => { - debug!(target: "database-cleanup", "Received Ctrl+C, performing last cleanup"); - perform_cleanup(); - } - _ = terminate.recv() => { - debug!(target: "database-cleanup","Received SIGTERM, performing last cleanup"); - perform_cleanup(); - } - }; - #[cfg(not(unix))] - { - i.tick().await; - debug!(target: "database-cleanup", "Timer ticked") - } - perform_cleanup(); - } - }); - } - - pub async fn start_presence_handler( - presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>, - ) { - tokio::spawn(async move { - match presence_handler(presence_timer_receiver).await { - Ok(()) => warn!("Presence maintenance task finished"), - Err(e) => error!("Presence maintenance task finished with error: {e}"), - } - }); - } + fn check_db_setup(config: &Config) -> Result<()> { + let path = Path::new(&config.database_path); + + let sqlite_exists = path.join("conduit.db").exists(); + let rocksdb_exists = path.join("IDENTITY").exists(); + + let mut count = 0; + + if sqlite_exists { + count += 1; + } + + if rocksdb_exists { + count += 1; + } + + if count > 1 { + warn!("Multiple databases at database_path detected"); + return Ok(()); + } + + if sqlite_exists && config.database_backend != "sqlite" { + return Err(Error::bad_config( + "Found sqlite at database_path, but is not specified in config.", + )); + } + + if rocksdb_exists && config.database_backend != "rocksdb" { + return Err(Error::bad_config( + "Found rocksdb at database_path, but is not specified in config.", + )); + } + + Ok(()) + } + + /// Load an existing database or create a new one. 
+ pub async fn load_or_create(config: Config) -> Result<()> { + Self::check_db_setup(&config)?; + + if !Path::new(&config.database_path).exists() { + debug!("Database path does not exist, assuming this is a new setup and creating it"); + std::fs::create_dir_all(&config.database_path).map_err(|e| { + error!("Failed to create database path: {e}"); + Error::BadConfig( + "Database folder doesn't exist and couldn't be created (e.g. due to missing permissions). Please \ + create the database folder yourself or allow conduwuit the permissions to create directories and \ + files.", + ) + })?; + } + + let builder: Arc = match &*config.database_backend { + "sqlite" => { + debug!("Got sqlite database backend"); + #[cfg(not(feature = "sqlite"))] + return Err(Error::BadConfig("Database backend not found.")); + #[cfg(feature = "sqlite")] + Arc::new(Arc::::open(&config)?) + }, + "rocksdb" => { + debug!("Got rocksdb database backend"); + #[cfg(not(feature = "rocksdb"))] + return Err(Error::BadConfig("Database backend not found.")); + #[cfg(feature = "rocksdb")] + Arc::new(Arc::::open(&config)?) + }, + _ => { + return Err(Error::BadConfig( + "Database backend not found. 
sqlite (not recommended) and rocksdb are the only supported backends.", + )); + }, + }; + + let (presence_sender, presence_receiver) = mpsc::unbounded_channel(); + + let db_raw = Box::new(Self { + db: builder.clone(), + userid_password: builder.open_tree("userid_password")?, + userid_displayname: builder.open_tree("userid_displayname")?, + userid_avatarurl: builder.open_tree("userid_avatarurl")?, + userid_blurhash: builder.open_tree("userid_blurhash")?, + userdeviceid_token: builder.open_tree("userdeviceid_token")?, + userdeviceid_metadata: builder.open_tree("userdeviceid_metadata")?, + userid_devicelistversion: builder.open_tree("userid_devicelistversion")?, + token_userdeviceid: builder.open_tree("token_userdeviceid")?, + onetimekeyid_onetimekeys: builder.open_tree("onetimekeyid_onetimekeys")?, + userid_lastonetimekeyupdate: builder.open_tree("userid_lastonetimekeyupdate")?, + keychangeid_userid: builder.open_tree("keychangeid_userid")?, + keyid_key: builder.open_tree("keyid_key")?, + userid_masterkeyid: builder.open_tree("userid_masterkeyid")?, + userid_selfsigningkeyid: builder.open_tree("userid_selfsigningkeyid")?, + userid_usersigningkeyid: builder.open_tree("userid_usersigningkeyid")?, + userfilterid_filter: builder.open_tree("userfilterid_filter")?, + todeviceid_events: builder.open_tree("todeviceid_events")?, + + userdevicesessionid_uiaainfo: builder.open_tree("userdevicesessionid_uiaainfo")?, + userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()), + readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?, + roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt + roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?, + typingid_userid: builder.open_tree("typingid_userid")?, + roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?, + roomuserid_presence: builder.open_tree("roomuserid_presence")?, + pduid_pdu: 
builder.open_tree("pduid_pdu")?, + eventid_pduid: builder.open_tree("eventid_pduid")?, + roomid_pduleaves: builder.open_tree("roomid_pduleaves")?, + + alias_roomid: builder.open_tree("alias_roomid")?, + aliasid_alias: builder.open_tree("aliasid_alias")?, + publicroomids: builder.open_tree("publicroomids")?, + + threadid_userids: builder.open_tree("threadid_userids")?, + + tokenids: builder.open_tree("tokenids")?, + + roomserverids: builder.open_tree("roomserverids")?, + serverroomids: builder.open_tree("serverroomids")?, + userroomid_joined: builder.open_tree("userroomid_joined")?, + roomuserid_joined: builder.open_tree("roomuserid_joined")?, + roomid_joinedcount: builder.open_tree("roomid_joinedcount")?, + roomid_invitedcount: builder.open_tree("roomid_invitedcount")?, + roomuseroncejoinedids: builder.open_tree("roomuseroncejoinedids")?, + userroomid_invitestate: builder.open_tree("userroomid_invitestate")?, + roomuserid_invitecount: builder.open_tree("roomuserid_invitecount")?, + userroomid_leftstate: builder.open_tree("userroomid_leftstate")?, + roomuserid_leftcount: builder.open_tree("roomuserid_leftcount")?, + + disabledroomids: builder.open_tree("disabledroomids")?, + + bannedroomids: builder.open_tree("bannedroomids")?, + + lazyloadedids: builder.open_tree("lazyloadedids")?, + + userroomid_notificationcount: builder.open_tree("userroomid_notificationcount")?, + userroomid_highlightcount: builder.open_tree("userroomid_highlightcount")?, + roomuserid_lastnotificationread: builder.open_tree("userroomid_highlightcount")?, // NOTE(review): tree name mismatch — this field opens the "userroomid_highlightcount" tree; presumably a historical typo kept for on-disk compatibility. Confirm before ever renaming, as renaming would orphan existing data. + + statekey_shortstatekey: builder.open_tree("statekey_shortstatekey")?, + shortstatekey_statekey: builder.open_tree("shortstatekey_statekey")?, + + shorteventid_authchain: builder.open_tree("shorteventid_authchain")?, + + roomid_shortroomid: builder.open_tree("roomid_shortroomid")?, + + shortstatehash_statediff: builder.open_tree("shortstatehash_statediff")?, + eventid_shorteventid: builder.open_tree("eventid_shorteventid")?, + 
shorteventid_eventid: builder.open_tree("shorteventid_eventid")?, + shorteventid_shortstatehash: builder.open_tree("shorteventid_shortstatehash")?, + roomid_shortstatehash: builder.open_tree("roomid_shortstatehash")?, + roomsynctoken_shortstatehash: builder.open_tree("roomsynctoken_shortstatehash")?, + statehash_shortstatehash: builder.open_tree("statehash_shortstatehash")?, + + eventid_outlierpdu: builder.open_tree("eventid_outlierpdu")?, + softfailedeventids: builder.open_tree("softfailedeventids")?, + + tofrom_relation: builder.open_tree("tofrom_relation")?, + referencedevents: builder.open_tree("referencedevents")?, + roomuserdataid_accountdata: builder.open_tree("roomuserdataid_accountdata")?, + roomusertype_roomuserdataid: builder.open_tree("roomusertype_roomuserdataid")?, + mediaid_file: builder.open_tree("mediaid_file")?, + url_previews: builder.open_tree("url_previews")?, + backupid_algorithm: builder.open_tree("backupid_algorithm")?, + backupid_etag: builder.open_tree("backupid_etag")?, + backupkeyid_backup: builder.open_tree("backupkeyid_backup")?, + userdevicetxnid_response: builder.open_tree("userdevicetxnid_response")?, + servername_educount: builder.open_tree("servername_educount")?, + servernameevent_data: builder.open_tree("servernameevent_data")?, + servercurrentevent_data: builder.open_tree("servercurrentevent_data")?, + id_appserviceregistrations: builder.open_tree("id_appserviceregistrations")?, + senderkey_pusher: builder.open_tree("senderkey_pusher")?, + global: builder.open_tree("global")?, + server_signingkeys: builder.open_tree("server_signingkeys")?, + + cached_registrations: Arc::new(RwLock::new(HashMap::new())), + pdu_cache: Mutex::new(LruCache::new( + config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"), + )), + auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)), + shorteventid_cache: Mutex::new(LruCache::new( + (100_000.0 * 
config.conduit_cache_capacity_modifier) as usize, + )), + eventidshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + shortstatekey_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + statekeyshort_cache: Mutex::new(LruCache::new( + (100_000.0 * config.conduit_cache_capacity_modifier) as usize, + )), + our_real_users_cache: RwLock::new(HashMap::new()), + appservice_in_room_cache: RwLock::new(HashMap::new()), + lasttimelinecount_cache: Mutex::new(HashMap::new()), + presence_timer_sender: Arc::new(presence_sender), + }); + + let db = Box::leak(db_raw); + + let services_raw = Box::new(Services::build(db, config)?); + + // This is the first and only time we initialize the SERVICE static + *SERVICES.write().unwrap() = Some(Box::leak(services_raw)); + + // Matrix resource ownership is based on the server name; changing it + // requires recreating the database from scratch. + if services().users.count()? > 0 { + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + if !services().users.exists(&conduit_user)? { + error!("The {} server user does not exist, and the database is not new.", conduit_user); + return Err(Error::bad_database( + "Cannot reuse an existing database after changing the server name, please delete the old one \ + first.", + )); + } + } + + // If the database has any data, perform data migrations before starting + // do not increment the db version if the user is not using sha256_media + let latest_database_version = if cfg!(feature = "sha256_media") { + 14 + } else { + 13 + }; + + if services().users.count()? > 0 { + // MIGRATIONS + if services().globals.database_version()? 
< 1 { + for (roomserverid, _) in db.roomserverids.iter() { + let mut parts = roomserverid.split(|&b| b == 0xFF); + let room_id = parts.next().expect("split always returns one element"); + let servername = match parts.next() { + Some(s) => s, + None => { + error!("Migration: Invalid roomserverid in db."); + continue; + }, + }; + let mut serverroomid = servername.to_vec(); + serverroomid.push(0xFF); + serverroomid.extend_from_slice(room_id); + + db.serverroomids.insert(&serverroomid, &[])?; + } + + services().globals.bump_database_version(1)?; + + warn!("Migration: 0 -> 1 finished"); + } + + if services().globals.database_version()? < 2 { + // We accidentally inserted hashed versions of "" into the db instead of just "" + for (userid, password) in db.userid_password.iter() { + let salt = SaltString::generate(thread_rng()); + let empty_pass = services() + .globals + .argon + .hash_password(b"", &salt) + .expect("our own password to be properly hashed"); + let empty_hashed_password = + services().globals.argon.verify_password(&password, &empty_pass).is_ok(); + + if empty_hashed_password { + db.userid_password.insert(&userid, b"")?; + } + } + + services().globals.bump_database_version(2)?; + + warn!("Migration: 1 -> 2 finished"); + } + + if services().globals.database_version()? < 3 { + // Move media to filesystem + for (key, content) in db.mediaid_file.iter() { + if content.is_empty() { + continue; + } + + #[allow(deprecated)] + let path = services().globals.get_media_file(&key); + let mut file = fs::File::create(path)?; + file.write_all(&content)?; + db.mediaid_file.insert(&key, &[])?; + } + + services().globals.bump_database_version(3)?; + + warn!("Migration: 2 -> 3 finished"); + } + + if services().globals.database_version()? < 4 { + // Add federated users to services() as deactivated + for our_user in services().users.iter() { + let our_user = our_user?; + if services().users.is_deactivated(&our_user)? 
{ + continue; + } + for room in services().rooms.state_cache.rooms_joined(&our_user) { + for user in services().rooms.state_cache.room_members(&room?) { + let user = user?; + if user.server_name() != services().globals.server_name() { + info!(?user, "Migration: creating user"); + services().users.create(&user, None)?; + } + } + } + } + + services().globals.bump_database_version(4)?; + + warn!("Migration: 3 -> 4 finished"); + } + + if services().globals.database_version()? < 5 { + // Upgrade user data store + for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() { + let mut parts = roomuserdataid.split(|&b| b == 0xFF); + let room_id = parts.next().unwrap(); + let user_id = parts.next().unwrap(); + let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap(); + + let mut key = room_id.to_vec(); + key.push(0xFF); + key.extend_from_slice(user_id); + key.push(0xFF); + key.extend_from_slice(event_type); + + db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?; + } + + services().globals.bump_database_version(5)?; + + warn!("Migration: 4 -> 5 finished"); + } + + if services().globals.database_version()? < 6 { + // Set room member count + for (roomid, _) in db.roomid_shortstatehash.iter() { + let string = utils::string_from_bytes(&roomid).unwrap(); + let room_id = <&RoomId>::try_from(string.as_str()).unwrap(); + services().rooms.state_cache.update_joined_count(room_id)?; + } + + services().globals.bump_database_version(6)?; + + warn!("Migration: 5 -> 6 finished"); + } + + if services().globals.database_version()? 
< 7 { + // Upgrade state store + let mut last_roomstates: HashMap = HashMap::new(); + let mut current_sstatehash: Option = None; + let mut current_room = None; + let mut current_state = HashSet::new(); + let mut counter = 0; + + let mut handle_state = |current_sstatehash: u64, + current_room: &RoomId, + current_state: HashSet<_>, + last_roomstates: &mut HashMap<_, _>| { + counter += 1; + let last_roomsstatehash = last_roomstates.get(current_room); + + let states_parents = last_roomsstatehash.map_or_else( + || Ok(Vec::new()), + |&last_roomsstatehash| { + services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash) + }, + )?; + + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew = + current_state.difference(&parent_stateinfo.1).copied().collect::>(); + + let statediffremoved = + parent_stateinfo.1.difference(¤t_state).copied().collect::>(); + + (statediffnew, statediffremoved) + } else { + (current_state, HashSet::new()) + }; + + services().rooms.state_compressor.save_state_from_diff( + current_sstatehash, + Arc::new(statediffnew), + Arc::new(statediffremoved), + 2, // every state change is 2 event changes on average + states_parents, + )?; + + /* + let mut tmp = services().rooms.load_shortstatehash_info(¤t_sstatehash)?; + let state = tmp.pop().unwrap(); + println!( + "{}\t{}{:?}: {:?} + {:?} - {:?}", + current_room, + " ".repeat(tmp.len()), + utils::u64_from_bytes(¤t_sstatehash).unwrap(), + tmp.last().map(|b| utils::u64_from_bytes(&b.0).unwrap()), + state + .2 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>(), + state + .3 + .iter() + .map(|b| utils::u64_from_bytes(&b[size_of::()..]).unwrap()) + .collect::>() + ); + */ + + Ok::<_, Error>(()) + }; + + for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() { + let sstatehash = + utils::u64_from_bytes(&k[0..size_of::()]).expect("number of bytes is correct"); + let sstatekey = 
k[size_of::()..].to_vec(); + if Some(sstatehash) != current_sstatehash { + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_deref().unwrap(), + current_state, + &mut last_roomstates, + )?; + last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash); + } + current_state = HashSet::new(); + current_sstatehash = Some(sstatehash); + + let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap(); + let string = utils::string_from_bytes(&event_id).unwrap(); + let event_id = <&EventId>::try_from(string.as_str()).unwrap(); + let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap(); + + if Some(&pdu.room_id) != current_room.as_ref() { + current_room = Some(pdu.room_id.clone()); + } + } + + let mut val = sstatekey; + val.extend_from_slice(&seventid); + current_state.insert(val.try_into().expect("size is correct")); + } + + if let Some(current_sstatehash) = current_sstatehash { + handle_state( + current_sstatehash, + current_room.as_deref().unwrap(), + current_state, + &mut last_roomstates, + )?; + } + + services().globals.bump_database_version(7)?; + + warn!("Migration: 6 -> 7 finished"); + } + + if services().globals.database_version()? 
< 8 { + // Generate short room ids for all rooms + for (room_id, _) in db.roomid_shortstatehash.iter() { + let shortroomid = services().globals.next_count()?.to_be_bytes(); + db.roomid_shortroomid.insert(&room_id, &shortroomid)?; + info!("Migration: 8"); + } + // Update pduids db layout + let mut batch = db.pduid_pdu.iter().filter_map(|(key, v)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(2, |&b| b == 0xFF); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + + let mut new_key = short_room_id; + new_key.extend_from_slice(count); + + Some((new_key, v)) + }); + + db.pduid_pdu.insert_batch(&mut batch)?; + + let mut batch2 = db.eventid_pduid.iter().filter_map(|(k, value)| { + if !value.starts_with(b"!") { + return None; + } + let mut parts = value.splitn(2, |&b| b == 0xFF); + let room_id = parts.next().unwrap(); + let count = parts.next().unwrap(); + + let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + + let mut new_value = short_room_id; + new_value.extend_from_slice(count); + + Some((k, new_value)) + }); + + db.eventid_pduid.insert_batch(&mut batch2)?; + + services().globals.bump_database_version(8)?; + + warn!("Migration: 7 -> 8 finished"); + } + + if services().globals.database_version()? 
< 9 { + // Update tokenids db layout + let mut iter = db + .tokenids + .iter() + .filter_map(|(key, _)| { + if !key.starts_with(b"!") { + return None; + } + let mut parts = key.splitn(4, |&b| b == 0xFF); + let room_id = parts.next().unwrap(); + let word = parts.next().unwrap(); + let _pdu_id_room = parts.next().unwrap(); + let pdu_id_count = parts.next().unwrap(); + + let short_room_id = + db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist"); + let mut new_key = short_room_id; + new_key.extend_from_slice(word); + new_key.push(0xFF); + new_key.extend_from_slice(pdu_id_count); + Some((new_key, Vec::new())) + }) + .peekable(); + + while iter.peek().is_some() { + db.tokenids.insert_batch(&mut iter.by_ref().take(1000))?; + debug!("Inserted smaller batch"); + } + + info!("Deleting starts"); + + let batch2: Vec<_> = db + .tokenids + .iter() + .filter_map(|(key, _)| { + if key.starts_with(b"!") { + Some(key) + } else { + None + } + }) + .collect(); + + for key in batch2 { + db.tokenids.remove(&key)?; + } + + services().globals.bump_database_version(9)?; + + warn!("Migration: 8 -> 9 finished"); + } + + if services().globals.database_version()? < 10 { + // Add other direction for shortstatekeys + for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() { + db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?; + } + + // Force E2EE device list updates so we can send them over federation + for user_id in services().users.iter().filter_map(std::result::Result::ok) { + services().users.mark_device_key_update(&user_id)?; + } + + services().globals.bump_database_version(10)?; + + warn!("Migration: 9 -> 10 finished"); + } + + if services().globals.database_version()? < 11 { + db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?; + services().globals.bump_database_version(11)?; + + warn!("Migration: 10 -> 11 finished"); + } + + if services().globals.database_version()? < 12 { + for username in services().users.list_local_users()? 
{ + let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name()) + { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + }, + }; + + let raw_rules_list = services() + .account_data + .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) + .unwrap() + .expect("Username is invalid"); + + let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); + let rules_list = &mut account_data.content.global; + + //content rule + { + let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"]; + + let rule = rules_list.content.get(content_rule_transformation[0]); + if rule.is_some() { + let mut rule = rule.unwrap().clone(); + rule.rule_id = content_rule_transformation[1].to_owned(); + rules_list.content.shift_remove(content_rule_transformation[0]); + rules_list.content.insert(rule); + } + } + + //underride rules + { + let underride_rule_transformation = [ + [".m.rules.call", ".m.rule.call"], + [".m.rules.room_one_to_one", ".m.rule.room_one_to_one"], + [".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"], + [".m.rules.message", ".m.rule.message"], + [".m.rules.encrypted", ".m.rule.encrypted"], + ]; + + for transformation in underride_rule_transformation { + let rule = rules_list.underride.get(transformation[0]); + if let Some(rule) = rule { + let mut rule = rule.clone(); + rule.rule_id = transformation[1].to_owned(); + rules_list.underride.shift_remove(transformation[0]); + rules_list.underride.insert(rule); + } + } + } + + services().account_data.update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; + } + + services().globals.bump_database_version(12)?; + + warn!("Migration: 11 -> 12 finished"); + } + + // This migration can be reused as-is anytime the server-default rules are + // updated. 
+ if services().globals.database_version()? < 13 { + for username in services().users.list_local_users()? { + let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name()) + { + Ok(u) => u, + Err(e) => { + warn!("Invalid username {username}: {e}"); + continue; + }, + }; + + let raw_rules_list = services() + .account_data + .get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into()) + .unwrap() + .expect("Username is invalid"); + + let mut account_data = serde_json::from_str::(raw_rules_list.get()).unwrap(); + + let user_default_rules = ruma::push::Ruleset::server_default(&user); + account_data.content.global.update_with_server_default(user_default_rules); + + services().account_data.update( + None, + &user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(account_data).expect("to json value always works"), + )?; + } + + services().globals.bump_database_version(13)?; + + warn!("Migration: 12 -> 13 finished"); + } + + if services().globals.database_version()? 
< 14 && cfg!(feature = "sha256_media") { + warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names"); + // Move old media files to new names + for (key, _) in db.mediaid_file.iter() { + // we know that this method is deprecated, but we need to use it to migrate the + // old files to the new location + // + // TODO: remove this once we're sure that all users have migrated + #[allow(deprecated)] + let old_path = services().globals.get_media_file(&key); + let path = services().globals.get_media_file_new(&key); + // move the file to the new location + if old_path.exists() { + tokio::fs::rename(&old_path, &path).await?; + } + } + + services().globals.bump_database_version(14)?; + + warn!("Migration: 13 -> 14 finished"); + } + + assert_eq!( + services().globals.database_version().unwrap(), + latest_database_version, + "Failed asserting local database version {} is equal to known latest conduwuit database version {}", + services().globals.database_version().unwrap(), + latest_database_version + ); + + { + let patterns = &services().globals.config.forbidden_usernames; + if !patterns.is_empty() { + for user in services().users.iter() { + let user_id = user?; + let matches = patterns.matches(user_id.localpart()); + if matches.matched_any() { + warn!( + "User {} matches the following forbidden username patterns: {}", + user_id.to_string(), + matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") + ); + } + } + } + } + + { + let patterns = &services().globals.config.forbidden_room_names; + if !patterns.is_empty() { + for address in services().rooms.metadata.iter_ids() { + let room_id = address?; + let room_aliases = services().rooms.alias.local_aliases_for_room(&room_id); + for room_alias_result in room_aliases { + let room_alias = room_alias_result?; + let matches = patterns.matches(room_alias.alias()); + if matches.matched_any() { + warn!( + "Room with alias {} ({}) matches the following forbidden room name patterns: {}", 
+ room_alias, + &room_id, + matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ") + ); + } + } + } + } + } + + info!( + "Loaded {} database with version {}", + services().globals.config.database_backend, + latest_database_version + ); + } else { + services().globals.bump_database_version(latest_database_version)?; + + // Create the admin room and server user on first run + services().admin.create_admin_room().await?; + + warn!( + "Created new {} database with version {}", + services().globals.config.database_backend, + latest_database_version + ); + } + + services().admin.start_handler(); + + // Set emergency access for the conduit user + match set_emergency_access() { + Ok(pwd_set) => { + if pwd_set { + warn!( + "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ + account recovery!" + ); + services().admin.send_message(RoomMessageEventContent::text_plain( + "The Conduit account emergency password is set! Please unset it as soon as you finish admin \ + account recovery!", + )); + } + }, + Err(e) => { + error!("Could not set the configured emergency password for the conduit user: {}", e); + }, + }; + + services().sending.start_handler(); + + Self::start_cleanup_task().await; + if services().globals.allow_check_for_updates() { + Self::start_check_for_updates_task(); + } + if services().globals.allow_local_presence() { + Self::start_presence_handler(presence_receiver).await; + } + + Ok(()) + } + + pub fn flush(&self) -> Result<()> { + let start = std::time::Instant::now(); + + let res = self.db.flush(); + + debug!("flush: took {:?}", start.elapsed()); + + res + } + + #[tracing::instrument] + pub fn start_check_for_updates_task() { + tokio::spawn(async move { + let timer_interval = Duration::from_secs(60 * 60); + let mut i = interval(timer_interval); + loop { + i.tick().await; + let _ = Self::try_handle_updates().await; + } + }); + } + + async fn try_handle_updates() -> Result<()> { + let response = + 
services().globals.default_client().get("https://pupbrain.dev/check-for-updates/stable").send().await?; + + #[derive(Deserialize)] + struct CheckForUpdatesResponseEntry { + id: u64, + date: String, + message: String, + } + #[derive(Deserialize)] + struct CheckForUpdatesResponse { + updates: Vec, + } + + let response = serde_json::from_str::(&response.text().await?).map_err(|e| { + error!("Bad check for updates response: {e}"); + Error::BadServerResponse("Bad version check response") + })?; + + let mut last_update_id = services().globals.last_check_for_updates_id()?; + for update in response.updates { + last_update_id = last_update_id.max(update.id); + if update.id > services().globals.last_check_for_updates_id()? { + error!("{}", update.message); + services().admin.send_message(RoomMessageEventContent::text_plain(format!( + "@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}", + update.date, update.message + ))); + } + } + services().globals.update_check_for_updates_id(last_update_id)?; + + Ok(()) + } + + #[tracing::instrument] + pub async fn start_cleanup_task() { + #[cfg(unix)] + use tokio::signal::unix::{signal, SignalKind}; + use tokio::time::Instant; + + let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval)); + + fn perform_cleanup() { + let start = Instant::now(); + if let Err(e) = services().globals.cleanup() { + error!(target: "database-cleanup", "Ran into an error during cleanup: {}", e); + } else { + debug!(target: "database-cleanup", "Finished cleanup in {:#?}.", start.elapsed()); + } + } + + tokio::spawn(async move { + let mut i = interval(timer_interval); + #[cfg(unix)] + let mut hangup = signal(SignalKind::hangup()).unwrap(); + let mut ctrl_c = signal(SignalKind::interrupt()).unwrap(); + let mut terminate = signal(SignalKind::terminate()).unwrap(); + + loop { + #[cfg(unix)] + tokio::select! 
{ + _ = i.tick() => { + debug!(target: "database-cleanup", "Timer ticked"); + } + _ = hangup.recv() => { + debug!(target: "database-cleanup","Received SIGHUP"); + } + _ = ctrl_c.recv() => { + debug!(target: "database-cleanup", "Received Ctrl+C, performing last cleanup"); + perform_cleanup(); + } + _ = terminate.recv() => { + debug!(target: "database-cleanup","Received SIGTERM, performing last cleanup"); + perform_cleanup(); + } + }; + #[cfg(not(unix))] + { + i.tick().await; + debug!(target: "database-cleanup", "Timer ticked") + } + perform_cleanup(); + } + }); + } + + pub async fn start_presence_handler(presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>) { + tokio::spawn(async move { + match presence_handler(presence_timer_receiver).await { + Ok(()) => warn!("Presence maintenance task finished"), + Err(e) => error!("Presence maintenance task finished with error: {e}"), + } + }); + } } -/// Sets the emergency password and push rules for the @conduit account in case emergency password is set +/// Sets the emergency password and push rules for the @conduit account in case +/// emergency password is set fn set_emergency_access() -> Result { - let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is a valid UserId"); + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is a valid UserId"); - services().users.set_password( - &conduit_user, - services().globals.emergency_password().as_deref(), - )?; + services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?; - let (ruleset, res) = match services().globals.emergency_password() { - Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)), - None => (Ruleset::new(), Ok(false)), - }; + let (ruleset, res) = match services().globals.emergency_password() { + Some(_) => (Ruleset::server_default(&conduit_user), 
Ok(true)), + None => (Ruleset::new(), Ok(false)), + }; - services().account_data.update( - None, - &conduit_user, - GlobalAccountDataEventType::PushRules.to_string().into(), - &serde_json::to_value(&GlobalAccountDataEvent { - content: PushRulesEventContent { global: ruleset }, - }) - .expect("to json value always works"), - )?; + services().account_data.update( + None, + &conduit_user, + GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(&GlobalAccountDataEvent { + content: PushRulesEventContent { + global: ruleset, + }, + }) + .expect("to json value always works"), + )?; - res + res } diff --git a/src/lib.rs b/src/lib.rs index 0f808fb7..45b43e9d 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -15,8 +15,5 @@ pub use utils::error::{Error, Result}; pub static SERVICES: RwLock>> = RwLock::new(None); pub fn services() -> &'static Services<'static> { - SERVICES - .read() - .unwrap() - .expect("SERVICES should be initialized when this is called") + SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called") } diff --git a/src/main.rs b/src/main.rs index baa9fc98..1b8a7179 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,56 +1,55 @@ use std::{ - fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path, - sync::atomic, time::Duration, + fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path, sync::atomic, + time::Duration, }; use axum::{ - extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, - response::IntoResponse, - routing::{get, on, MethodFilter}, - Router, + extract::{DefaultBodyLimit, FromRequestParts, MatchedPath}, + response::IntoResponse, + routing::{get, on, MethodFilter}, + Router, }; use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle}; +#[cfg(feature = "axum_dual_protocol")] +use axum_server_dual_protocol::ServerExt; +use clap::Parser; use conduit::api::{client_server, server_server}; +pub 
use conduit::*; // Re-export everything from the library crate use either::Either::{Left, Right}; use figment::{ - providers::{Env, Format, Toml}, - Figment, + providers::{Env, Format, Toml}, + Figment, }; use http::{ - header::{self, HeaderName}, - Method, StatusCode, Uri, + header::{self, HeaderName}, + Method, StatusCode, Uri, }; use hyper::Server; use hyperlocal::SocketIncoming; use ruma::api::{ - client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::UiaaResponse, - }, - IncomingRequest, + client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::UiaaResponse, + }, + IncomingRequest, +}; +#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] +use tikv_jemallocator::Jemalloc; +use tokio::{ + net::UnixListener, + signal, + sync::{oneshot, oneshot::Sender}, + task::JoinSet, }; -use tokio::{net::UnixListener, signal, sync::oneshot, task::JoinSet}; use tower::ServiceBuilder; use tower_http::{ - cors::{self, CorsLayer}, - trace::{DefaultOnFailure, TraceLayer}, - ServiceBuilderExt as _, + cors::{self, CorsLayer}, + trace::{DefaultOnFailure, TraceLayer}, + ServiceBuilderExt as _, }; use tracing::{debug, error, info, warn, Level}; use tracing_subscriber::{prelude::*, EnvFilter}; -use tokio::sync::oneshot::Sender; - -use clap::Parser; - -#[cfg(feature = "axum_dual_protocol")] -use axum_server_dual_protocol::ServerExt; - -pub use conduit::*; // Re-export everything from the library crate - -#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] -use tikv_jemallocator::Jemalloc; - #[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))] #[global_allocator] static GLOBAL: Jemalloc = Jemalloc; @@ -61,721 +60,708 @@ struct Args; #[tokio::main] async fn main() { - Args::parse(); - // Initialize config - let raw_config = if Env::var("CONDUIT_CONFIG").is_some() { - Figment::new() - .merge( - Toml::file(Env::var("CONDUIT_CONFIG").expect( - "The CONDUIT_CONFIG environment variable was set but appears to be invalid. 
This should be set to the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.", - )) - .nested(), - ) - .merge(Env::prefixed("CONDUIT_").global()) - } else { - Figment::new().merge(Env::prefixed("CONDUIT_").global()) - }; + Args::parse(); + // Initialize config + let raw_config = if Env::var("CONDUIT_CONFIG").is_some() { + Figment::new() + .merge( + Toml::file(Env::var("CONDUIT_CONFIG").expect( + "The CONDUIT_CONFIG environment variable was set but appears to be invalid. This should be set to \ + the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.", + )) + .nested(), + ) + .merge(Env::prefixed("CONDUIT_").global()) + } else { + Figment::new().merge(Env::prefixed("CONDUIT_").global()) + }; - let config = match raw_config.extract::() { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occurred: {e}"); - return; - } - }; + let config = match raw_config.extract::() { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. 
The following error occurred: {e}"); + return; + }, + }; - if config.allow_jaeger { - opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); - let tracer = opentelemetry_jaeger::new_agent_pipeline() - .with_auto_split_batch(true) - .with_service_name("conduit") - .install_batch(opentelemetry_sdk::runtime::Tokio) - .unwrap(); - let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); + if config.allow_jaeger { + opentelemetry::global::set_text_map_propagator(opentelemetry_jaeger::Propagator::new()); + let tracer = opentelemetry_jaeger::new_agent_pipeline() + .with_auto_split_batch(true) + .with_service_name("conduit") + .install_batch(opentelemetry_sdk::runtime::Tokio) + .unwrap(); + let telemetry = tracing_opentelemetry::layer().with_tracer(tracer); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!( - "It looks like your log config is invalid. The following error occurred: {e}" - ); - EnvFilter::try_new("warn").unwrap() - } - }; + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your log config is invalid. 
The following error occurred: {e}"); + EnvFilter::try_new("warn").unwrap() + }, + }; - let subscriber = tracing_subscriber::Registry::default() - .with(filter_layer) - .with(telemetry); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } else if config.tracing_flame { - let registry = tracing_subscriber::Registry::default(); - let (flame_layer, _guard) = - tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); - let flame_layer = flame_layer.with_empty_samples(false); + let subscriber = tracing_subscriber::Registry::default().with(filter_layer).with(telemetry); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } else if config.tracing_flame { + let registry = tracing_subscriber::Registry::default(); + let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap(); + let flame_layer = flame_layer.with_empty_samples(false); - let filter_layer = EnvFilter::new("trace,h2=off"); + let filter_layer = EnvFilter::new("trace,h2=off"); - let subscriber = registry.with(filter_layer).with(flame_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } else { - let registry = tracing_subscriber::Registry::default(); - let fmt_layer = tracing_subscriber::fmt::Layer::new(); - let filter_layer = match EnvFilter::try_new(&config.log) { - Ok(s) => s, - Err(e) => { - eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}"); - EnvFilter::try_new("warn").unwrap() - } - }; + let subscriber = registry.with(filter_layer).with(flame_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } else { + let registry = tracing_subscriber::Registry::default(); + let fmt_layer = tracing_subscriber::fmt::Layer::new(); + let filter_layer = match EnvFilter::try_new(&config.log) { + Ok(s) => s, + Err(e) => { + eprintln!("It looks like your config is invalid. 
The following error occured while parsing it: {e}"); + EnvFilter::try_new("warn").unwrap() + }, + }; - let subscriber = registry.with(filter_layer).with(fmt_layer); - tracing::subscriber::set_global_default(subscriber).unwrap(); - } + let subscriber = registry.with(filter_layer).with(fmt_layer); + tracing::subscriber::set_global_default(subscriber).unwrap(); + } - // This is needed for opening lots of file descriptors, which tends to - // happen more often when using RocksDB and making lots of federation - // connections at startup. The soft limit is usually 1024, and the hard - // limit is usually 512000; I've personally seen it hit >2000. - // - // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 - // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 - #[cfg(unix)] - maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); + // This is needed for opening lots of file descriptors, which tends to + // happen more often when using RocksDB and making lots of federation + // connections at startup. The soft limit is usually 1024, and the hard + // limit is usually 512000; I've personally seen it hit >2000. 
+ // + // * https://www.freedesktop.org/software/systemd/man/systemd.exec.html#id-1.12.2.1.17.6 + // * https://github.com/systemd/systemd/commit/0abf94923b4a95a7d89bc526efc84e7ca2b71741 + #[cfg(unix)] + maximize_fd_limit().expect("Unable to increase maximum soft and hard file descriptor limit"); - config.warn_deprecated(); - config.warn_unknown_key(); + config.warn_deprecated(); + config.warn_unknown_key(); - // don't start if we're listening on both UNIX sockets and TCP at same time - if config.is_dual_listening(raw_config) { - return; - }; + // don't start if we're listening on both UNIX sockets and TCP at same time + if config.is_dual_listening(raw_config) { + return; + }; - info!("Loading database"); - let db_load_time = std::time::Instant::now(); - if let Err(error) = KeyValueDatabase::load_or_create(config).await { - error!(?error, "The database couldn't be loaded or created"); - return; - }; - info!("Database took {:?} to load", db_load_time.elapsed()); + info!("Loading database"); + let db_load_time = std::time::Instant::now(); + if let Err(error) = KeyValueDatabase::load_or_create(config).await { + error!(?error, "The database couldn't be loaded or created"); + return; + }; + info!("Database took {:?} to load", db_load_time.elapsed()); - let config = &services().globals.config; + let config = &services().globals.config; - /* ad-hoc config validation/checks */ + /* ad-hoc config validation/checks */ - if config.address.is_loopback() { - debug!( - "Found loopback listening address {}, running checks if we're in a container.", - config.address - ); + if config.address.is_loopback() { + debug!( + "Found loopback listening address {}, running checks if we're in a container.", + config.address + ); - #[cfg(unix)] - if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() - /* Host */ - { - error!("You are detected using OpenVZ with a loopback/localhost listening address of {}. 
If you are using OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address); - } + #[cfg(unix)] + if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists() + /* Host */ + { + error!( + "You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \ + OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \ + will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", + config.address + ); + } - #[cfg(unix)] - if Path::new("/.dockerenv").exists() { - error!("You are detected using Docker with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Docker container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address); - } + #[cfg(unix)] + if Path::new("/.dockerenv").exists() { + error!( + "You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \ + reverse proxy on the host and require communication to conduwuit in the Docker container via \ + NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ + you can ignore.", + config.address + ); + } - #[cfg(unix)] - if Path::new("/run/.containerenv").exists() { - error!("You are detected using Podman with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Podman container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". 
If this is expected, you can ignore.", config.address); - } - } + #[cfg(unix)] + if Path::new("/run/.containerenv").exists() { + error!( + "You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \ + reverse proxy on the host and require communication to conduwuit in the Podman container via \ + NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \ + you can ignore.", + config.address + ); + } + } - // yeah, unless the user built a debug build hopefully for local testing only - if config.server_name == "your.server.name" && !cfg!(debug_assertions) { - error!("You must specify a valid server name for production usage of conduwuit."); - return; - } + // yeah, unless the user built a debug build hopefully for local testing only + if config.server_name == "your.server.name" && !cfg!(debug_assertions) { + error!("You must specify a valid server name for production usage of conduwuit."); + return; + } - if cfg!(debug_assertions) { - info!("Note: conduwuit was built without optimisations (i.e. debug build)"); - } + if cfg!(debug_assertions) { + info!("Note: conduwuit was built without optimisations (i.e. debug build)"); + } - // check if the user specified a registration token as `""` - if config.registration_token == Some(String::new()) { - error!("Registration token was specified but is empty (\"\")"); - return; - } + // check if the user specified a registration token as `""` + if config.registration_token == Some(String::new()) { + error!("Registration token was specified but is empty (\"\")"); + return; + } - if config.max_request_size < 4096 { - error!(?config.max_request_size, "Max request size is less than 4KB. Please increase it."); - } + if config.max_request_size < 4096 { + error!(?config.max_request_size, "Max request size is less than 4KB. 
Please increase it."); + } - // check if user specified valid IP CIDR ranges on startup - for cidr in services().globals.ip_range_denylist() { - let _ = ipaddress::IPAddress::parse(cidr) - .map_err(|e| error!("Error parsing specified IP CIDR range: {e}")); - } + // check if user specified valid IP CIDR ranges on startup + for cidr in services().globals.ip_range_denylist() { + let _ = ipaddress::IPAddress::parse(cidr).map_err(|e| error!("Error parsing specified IP CIDR range: {e}")); + } - if config.allow_registration - && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse - && config.registration_token.is_none() - { - error!("!! You have `allow_registration` enabled without a token configured in your config which means you are allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n - If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n - For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you want, please set the following config option to true: - `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`"); - return; - } + if config.allow_registration + && !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse + && config.registration_token.is_none() + { + error!( + "!! You have `allow_registration` enabled without a token configured in your config which means you are \ + allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n + If this is not the intended behaviour, please set a registration token with the `registration_token` config \ + option.\n + For security and safety reasons, conduwuit will shut down. 
If you are extra sure this is the desired behaviour \ + you want, please set the following config option to true: + `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`" + ); + return; + } - if config.allow_registration - && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse - && config.registration_token.is_none() - { - warn!("Open registration is enabled via setting `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to true without a registration token configured. You are expected to be aware of the risks now.\n - If this is not the desired behaviour, please set a registration token."); - } + if config.allow_registration + && config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse + && config.registration_token.is_none() + { + warn!( + "Open registration is enabled via setting \ + `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \ + true without a registration token configured. You are expected to be aware of the risks now.\n + If this is not the desired behaviour, please set a registration token." + ); + } - if config.allow_outgoing_presence && !config.allow_local_presence { - error!("Outgoing presence requires allowing local presence. Please enable \"allow_outgoing_presence\"."); - return; - } + if config.allow_outgoing_presence && !config.allow_local_presence { + error!("Outgoing presence requires allowing local presence. Please enable \"allow_outgoing_presence\"."); + return; + } - if config.allow_outgoing_presence { - warn!("! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing presence will not be very reliable due to this and any issues with federated outgoing presence are very likely attributed to this issue.\nIncoming presence and local presence are unaffected."); - } + if config.allow_outgoing_presence { + warn!( + "! 
Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing \ + presence will not be very reliable due to this and any issues with federated outgoing presence are very \ + likely attributed to this issue.\nIncoming presence and local presence are unaffected." + ); + } - if config - .url_preview_domain_contains_allowlist - .contains(&"*".to_owned()) - { - warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this."); - } - if config - .url_preview_domain_explicit_allowlist - .contains(&"*".to_owned()) - { - warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this."); - } - if config - .url_preview_url_contains_allowlist - .contains(&"*".to_owned()) - { - warn!("All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this."); - } + if config.url_preview_domain_contains_allowlist.contains(&"*".to_owned()) { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \ + This opens up significant attack surface to your server. You are expected to be aware of the risks by \ + doing this." + ); + } + if config.url_preview_domain_explicit_allowlist.contains(&"*".to_owned()) { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \ + This opens up significant attack surface to your server. You are expected to be aware of the risks by \ + doing this." 
+ ); + } + if config.url_preview_url_contains_allowlist.contains(&"*".to_owned()) { + warn!( + "All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \ + opens up significant attack surface to your server. You are expected to be aware of the risks by doing \ + this." + ); + } - /* end ad-hoc config validation/checks */ + /* end ad-hoc config validation/checks */ - info!("Starting server"); - if let Err(e) = run_server().await { - error!("Critical error starting server: {}", e); - }; + info!("Starting server"); + if let Err(e) = run_server().await { + error!("Critical error starting server: {}", e); + }; - // if server runs into critical error and shuts down, shut down the tracer provider if jaegar is used. - // awaiting run_server() is a blocking call so putting this after is fine, but not the other options above. - if config.allow_jaeger { - opentelemetry::global::shutdown_tracer_provider(); - } + // if server runs into critical error and shuts down, shut down the tracer + // provider if jaegar is used. awaiting run_server() is a blocking call so + // putting this after is fine, but not the other options above. 
+ if config.allow_jaeger { + opentelemetry::global::shutdown_tracer_provider(); + } } async fn run_server() -> io::Result<()> { - let config = &services().globals.config; + let config = &services().globals.config; - let addrs = match &config.port.ports { - Left(port) => { - // Left is only 1 value, so make a vec with 1 value only - let port_vec = [port]; + let addrs = match &config.port.ports { + Left(port) => { + // Left is only 1 value, so make a vec with 1 value only + let port_vec = [port]; - port_vec - .iter() - .copied() - .map(|port| SocketAddr::from((config.address, *port))) - .collect::>() - } - Right(ports) => ports - .iter() - .copied() - .map(|port| SocketAddr::from((config.address, port))) - .collect::>(), - }; + port_vec.iter().copied().map(|port| SocketAddr::from((config.address, *port))).collect::>() + }, + Right(ports) => ports.iter().copied().map(|port| SocketAddr::from((config.address, port))).collect::>(), + }; - let x_requested_with = HeaderName::from_static("x-requested-with"); - let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); + let x_requested_with = HeaderName::from_static("x-requested-with"); + let x_forwarded_for = HeaderName::from_static("x-forwarded-for"); - let middlewares = ServiceBuilder::new() - .sensitive_headers([header::AUTHORIZATION]) - .sensitive_request_headers([x_forwarded_for].into()) - .layer(axum::middleware::from_fn(spawn_task)) - .layer( - TraceLayer::new_for_http() - .make_span_with(|request: &http::Request<_>| { - let path = if let Some(path) = request.extensions().get::() { - path.as_str() - } else { - request.uri().path() - }; + let middlewares = ServiceBuilder::new() + .sensitive_headers([header::AUTHORIZATION]) + .sensitive_request_headers([x_forwarded_for].into()) + .layer(axum::middleware::from_fn(spawn_task)) + .layer( + TraceLayer::new_for_http() + .make_span_with(|request: &http::Request<_>| { + let path = if let Some(path) = request.extensions().get::() { + path.as_str() + } else { + 
request.uri().path() + }; - tracing::info_span!("http_request", %path) - }) - .on_failure(DefaultOnFailure::new().level(Level::INFO)), - ) - .layer(axum::middleware::from_fn(unrecognized_method)) - .layer( - CorsLayer::new() - .allow_origin(cors::Any) - .allow_methods([ - Method::GET, - Method::HEAD, - Method::POST, - Method::PUT, - Method::DELETE, - Method::OPTIONS, - ]) - .allow_headers([ - header::ORIGIN, - x_requested_with, - header::CONTENT_TYPE, - header::ACCEPT, - header::AUTHORIZATION, - ]) - .max_age(Duration::from_secs(86400)), - ) - .layer(DefaultBodyLimit::max( - config - .max_request_size - .try_into() - .expect("failed to convert max request size"), - )); + tracing::info_span!("http_request", %path) + }) + .on_failure(DefaultOnFailure::new().level(Level::INFO)), + ) + .layer(axum::middleware::from_fn(unrecognized_method)) + .layer( + CorsLayer::new() + .allow_origin(cors::Any) + .allow_methods([ + Method::GET, + Method::HEAD, + Method::POST, + Method::PUT, + Method::DELETE, + Method::OPTIONS, + ]) + .allow_headers([ + header::ORIGIN, + x_requested_with, + header::CONTENT_TYPE, + header::ACCEPT, + header::AUTHORIZATION, + ]) + .max_age(Duration::from_secs(86400)), + ) + .layer(DefaultBodyLimit::max( + config.max_request_size.try_into().expect("failed to convert max request size"), + )); - let app = if cfg!(feature = "zstd_compression") && config.zstd_compression { - debug!("zstd body compression is enabled"); - routes() - .layer(middlewares.compression()) - .into_make_service() - } else { - routes().layer(middlewares).into_make_service() - }; + let app = if cfg!(feature = "zstd_compression") && config.zstd_compression { + debug!("zstd body compression is enabled"); + routes().layer(middlewares.compression()).into_make_service() + } else { + routes().layer(middlewares).into_make_service() + }; - let handle = ServerHandle::new(); - let (tx, rx) = oneshot::channel::<()>(); + let handle = ServerHandle::new(); + let (tx, rx) = oneshot::channel::<()>(); - 
tokio::spawn(shutdown_signal(handle.clone(), tx)); + tokio::spawn(shutdown_signal(handle.clone(), tx)); - if let Some(path) = &config.unix_socket_path { - if path.exists() { - warn!( - "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", - path.display() - ); - tokio::fs::remove_file(&path).await?; - } + if let Some(path) = &config.unix_socket_path { + if path.exists() { + warn!( + "UNIX socket path {:#?} already exists (unclean shutdown?), attempting to remove it.", + path.display() + ); + tokio::fs::remove_file(&path).await?; + } - tokio::fs::create_dir_all(path.parent().unwrap()).await?; + tokio::fs::create_dir_all(path.parent().unwrap()).await?; - let socket_perms = config.unix_socket_perms.to_string(); - let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); + let socket_perms = config.unix_socket_perms.to_string(); + let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap(); - let listener = UnixListener::bind(path.clone())?; - tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)) - .await - .unwrap(); - let socket = SocketIncoming::from_listener(listener); + let listener = UnixListener::bind(path.clone())?; + tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)).await.unwrap(); + let socket = SocketIncoming::from_listener(listener); - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - info!("Listening at {:?}", path); - let server = Server::builder(socket).serve(app); - let graceful = server.with_graceful_shutdown(async { - rx.await.ok(); - }); + info!("Listening at {:?}", path); + let server = Server::builder(socket).serve(app); + let graceful = server.with_graceful_shutdown(async { + rx.await.ok(); + }); - if let Err(e) = graceful.await { - error!("Server error: {:?}", e); - } - } else { - match &config.tls { - Some(tls) => { - 
debug!( - "Using direct TLS. Certificate path {} and certificate private key path {}", - &tls.certs, &tls.key - ); - info!("Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS."); - let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; + if let Err(e) = graceful.await { + error!("Server error: {:?}", e); + } + } else { + match &config.tls { + Some(tls) => { + debug!( + "Using direct TLS. Certificate path {} and certificate private key path {}", + &tls.certs, &tls.key + ); + info!( + "Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit \ + directly with TLS." + ); + let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?; - if cfg!(feature = "axum_dual_protocol") { - info!( - "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only take affect if `dual_protocol` is enabled in `[global.tls]`" - ); - } + if cfg!(feature = "axum_dual_protocol") { + info!( + "conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. 
This \ + will only take affect if `dual_protocol` is enabled in `[global.tls]`" + ); + } - let mut join_set = JoinSet::new(); + let mut join_set = JoinSet::new(); - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - #[cfg(feature = "axum_dual_protocol")] - for addr in &addrs { - join_set.spawn( - axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) - .set_upgrade(false) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } else { - for addr in &addrs { - join_set.spawn( - bind_rustls(*addr, conf.clone()) - .handle(handle.clone()) - .serve(app.clone()), - ); - } - } + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + #[cfg(feature = "axum_dual_protocol")] + for addr in &addrs { + join_set.spawn( + axum_server_dual_protocol::bind_dual_protocol(*addr, conf.clone()) + .set_upgrade(false) + .handle(handle.clone()) + .serve(app.clone()), + ); + } + } else { + for addr in &addrs { + join_set.spawn(bind_rustls(*addr, conf.clone()).handle(handle.clone()).serve(app.clone())); + } + } - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { - warn!( - "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)", - addrs, &tls.certs - ); - } else { - info!( - "Listening on {:?} with TLS certificate {}", - addrs, &tls.certs - ); - } + if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol { + warn!( + "Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too \ + (insecure!)", + addrs, &tls.certs + ); + } else { + info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs); + } - join_set.join_next().await; - } - None => { - let mut join_set = JoinSet::new(); - for addr in &addrs { - 
join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); - } + join_set.join_next().await; + }, + None => { + let mut join_set = JoinSet::new(); + for addr in &addrs { + join_set.spawn(bind(*addr).handle(handle.clone()).serve(app.clone())); + } - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Ready]); - info!("Listening on {:?}", addrs); - join_set.join_next().await; - } - } - } + info!("Listening on {:?}", addrs); + join_set.join_next().await; + }, + } + } - Ok(()) + Ok(()) } async fn spawn_task( - req: axum::http::Request, - next: axum::middleware::Next, + req: axum::http::Request, next: axum::middleware::Next, ) -> std::result::Result { - if services().globals.shutdown.load(atomic::Ordering::Relaxed) { - return Err(StatusCode::SERVICE_UNAVAILABLE); - } - tokio::spawn(next.run(req)) - .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) + if services().globals.shutdown.load(atomic::Ordering::Relaxed) { + return Err(StatusCode::SERVICE_UNAVAILABLE); + } + tokio::spawn(next.run(req)).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR) } async fn unrecognized_method( - req: axum::http::Request, - next: axum::middleware::Next, + req: axum::http::Request, next: axum::middleware::Next, ) -> std::result::Result { - let method = req.method().clone(); - let uri = req.uri().clone(); - let inner = next.run(req).await; - if inner.status() == axum::http::StatusCode::METHOD_NOT_ALLOWED { - warn!("Method not allowed: {method} {uri}"); - return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { - kind: ErrorKind::Unrecognized, - message: "M_UNRECOGNIZED: Unrecognized request".to_owned(), - }, - status_code: StatusCode::METHOD_NOT_ALLOWED, - })) - .into_response()); - } - Ok(inner) + let method = req.method().clone(); + let uri = req.uri().clone(); + let inner = next.run(req).await; + if 
inner.status() == axum::http::StatusCode::METHOD_NOT_ALLOWED { + warn!("Method not allowed: {method} {uri}"); + return Ok(RumaResponse(UiaaResponse::MatrixError(RumaError { + body: ErrorBody::Standard { + kind: ErrorKind::Unrecognized, + message: "M_UNRECOGNIZED: Unrecognized request".to_owned(), + }, + status_code: StatusCode::METHOD_NOT_ALLOWED, + })) + .into_response()); + } + Ok(inner) } fn routes() -> Router { - Router::new() - .ruma_route(client_server::get_supported_versions_route) - .ruma_route(client_server::get_register_available_route) - .ruma_route(client_server::register_route) - .ruma_route(client_server::get_login_types_route) - .ruma_route(client_server::login_route) - .ruma_route(client_server::whoami_route) - .ruma_route(client_server::logout_route) - .ruma_route(client_server::logout_all_route) - .ruma_route(client_server::change_password_route) - .ruma_route(client_server::deactivate_route) - .ruma_route(client_server::third_party_route) - .ruma_route(client_server::request_3pid_management_token_via_email_route) - .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) - .ruma_route(client_server::get_capabilities_route) - .ruma_route(client_server::get_pushrules_all_route) - .ruma_route(client_server::set_pushrule_route) - .ruma_route(client_server::get_pushrule_route) - .ruma_route(client_server::set_pushrule_enabled_route) - .ruma_route(client_server::get_pushrule_enabled_route) - .ruma_route(client_server::get_pushrule_actions_route) - .ruma_route(client_server::set_pushrule_actions_route) - .ruma_route(client_server::delete_pushrule_route) - .ruma_route(client_server::get_room_event_route) - .ruma_route(client_server::get_room_aliases_route) - .ruma_route(client_server::get_filter_route) - .ruma_route(client_server::create_filter_route) - .ruma_route(client_server::set_global_account_data_route) - .ruma_route(client_server::set_room_account_data_route) - .ruma_route(client_server::get_global_account_data_route) - 
.ruma_route(client_server::get_room_account_data_route) - .ruma_route(client_server::set_displayname_route) - .ruma_route(client_server::get_displayname_route) - .ruma_route(client_server::set_avatar_url_route) - .ruma_route(client_server::get_avatar_url_route) - .ruma_route(client_server::get_profile_route) - .ruma_route(client_server::set_presence_route) - .ruma_route(client_server::get_presence_route) - .ruma_route(client_server::upload_keys_route) - .ruma_route(client_server::get_keys_route) - .ruma_route(client_server::claim_keys_route) - .ruma_route(client_server::create_backup_version_route) - .ruma_route(client_server::update_backup_version_route) - .ruma_route(client_server::delete_backup_version_route) - .ruma_route(client_server::get_latest_backup_info_route) - .ruma_route(client_server::get_backup_info_route) - .ruma_route(client_server::add_backup_keys_route) - .ruma_route(client_server::add_backup_keys_for_room_route) - .ruma_route(client_server::add_backup_keys_for_session_route) - .ruma_route(client_server::delete_backup_keys_for_room_route) - .ruma_route(client_server::delete_backup_keys_for_session_route) - .ruma_route(client_server::delete_backup_keys_route) - .ruma_route(client_server::get_backup_keys_for_room_route) - .ruma_route(client_server::get_backup_keys_for_session_route) - .ruma_route(client_server::get_backup_keys_route) - .ruma_route(client_server::set_read_marker_route) - .ruma_route(client_server::create_receipt_route) - .ruma_route(client_server::create_typing_event_route) - .ruma_route(client_server::create_room_route) - .ruma_route(client_server::redact_event_route) - .ruma_route(client_server::report_event_route) - .ruma_route(client_server::create_alias_route) - .ruma_route(client_server::delete_alias_route) - .ruma_route(client_server::get_alias_route) - .ruma_route(client_server::join_room_by_id_route) - .ruma_route(client_server::join_room_by_id_or_alias_route) - .ruma_route(client_server::joined_members_route) - 
.ruma_route(client_server::leave_room_route) - .ruma_route(client_server::forget_room_route) - .ruma_route(client_server::joined_rooms_route) - .ruma_route(client_server::kick_user_route) - .ruma_route(client_server::ban_user_route) - .ruma_route(client_server::unban_user_route) - .ruma_route(client_server::invite_user_route) - .ruma_route(client_server::set_room_visibility_route) - .ruma_route(client_server::get_room_visibility_route) - .ruma_route(client_server::get_public_rooms_route) - .ruma_route(client_server::get_public_rooms_filtered_route) - .ruma_route(client_server::search_users_route) - .ruma_route(client_server::get_member_events_route) - .ruma_route(client_server::get_protocols_route) - .ruma_route(client_server::send_message_event_route) - .ruma_route(client_server::send_state_event_for_key_route) - .ruma_route(client_server::get_state_events_route) - .ruma_route(client_server::get_state_events_for_key_route) - // Ruma doesn't have support for multiple paths for a single endpoint yet, and these routes - // share one Ruma request / response type pair with {get,send}_state_event_for_key_route - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - // These two endpoints allow trailing slashes - .route( - "/_matrix/client/r0/rooms/:room_id/state/:event_type/", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/state/:event_type/", - get(client_server::get_state_events_for_empty_key_route) - .put(client_server::send_state_event_for_empty_key_route), - ) - .ruma_route(client_server::sync_events_route) - 
.ruma_route(client_server::sync_events_v4_route) - .ruma_route(client_server::get_context_route) - .ruma_route(client_server::get_message_events_route) - .ruma_route(client_server::search_events_route) - .ruma_route(client_server::turn_server_route) - .ruma_route(client_server::send_event_to_device_route) - .ruma_route(client_server::get_media_config_route) - .ruma_route(client_server::get_media_preview_route) - .ruma_route(client_server::create_content_route) - .ruma_route(client_server::get_content_route) - .ruma_route(client_server::get_content_as_filename_route) - .ruma_route(client_server::get_content_thumbnail_route) - .ruma_route(client_server::get_devices_route) - .ruma_route(client_server::get_device_route) - .ruma_route(client_server::update_device_route) - .ruma_route(client_server::delete_device_route) - .ruma_route(client_server::delete_devices_route) - .ruma_route(client_server::get_tags_route) - .ruma_route(client_server::update_tag_route) - .ruma_route(client_server::delete_tag_route) - .ruma_route(client_server::upload_signing_keys_route) - .ruma_route(client_server::upload_signatures_route) - .ruma_route(client_server::get_key_changes_route) - .ruma_route(client_server::get_pushers_route) - .ruma_route(client_server::set_pushers_route) - // .ruma_route(client_server::third_party_route) - .ruma_route(client_server::upgrade_room_route) - .ruma_route(client_server::get_threads_route) - .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route) - .ruma_route(client_server::get_relating_events_with_rel_type_route) - .ruma_route(client_server::get_relating_events_route) - .ruma_route(client_server::get_hierarchy_route) - .ruma_route(server_server::get_server_version_route) - .route( - "/_matrix/key/v2/server", - get(server_server::get_server_keys_route), - ) - .route( - "/_matrix/key/v2/server/:key_id", - get(server_server::get_server_keys_deprecated_route), - ) - .ruma_route(server_server::get_public_rooms_route) - 
.ruma_route(server_server::get_public_rooms_filtered_route) - .ruma_route(server_server::send_transaction_message_route) - .ruma_route(server_server::get_event_route) - .ruma_route(server_server::get_backfill_route) - .ruma_route(server_server::get_missing_events_route) - .ruma_route(server_server::get_event_authorization_route) - .ruma_route(server_server::get_room_state_route) - .ruma_route(server_server::get_room_state_ids_route) - .ruma_route(server_server::create_join_event_template_route) - .ruma_route(server_server::create_join_event_v1_route) - .ruma_route(server_server::create_join_event_v2_route) - .ruma_route(server_server::create_invite_route) - .ruma_route(server_server::get_devices_route) - .ruma_route(server_server::get_room_information_route) - .ruma_route(server_server::get_profile_information_route) - .ruma_route(server_server::get_keys_route) - .ruma_route(server_server::claim_keys_route) - .route( - "/_matrix/client/r0/rooms/:room_id/initialSync", - get(initial_sync), - ) - .route( - "/_matrix/client/v3/rooms/:room_id/initialSync", - get(initial_sync), - ) - .route( - "/client/server.json", - get(client_server::syncv3_client_server_json), - ) - .route( - "/.well-known/matrix/client", - get(client_server::well_known_client_route), - ) - .route( - "/.well-known/matrix/server", - get(server_server::well_known_server_route), - ) - .route("/", get(it_works)) - .fallback(not_found) + Router::new() + .ruma_route(client_server::get_supported_versions_route) + .ruma_route(client_server::get_register_available_route) + .ruma_route(client_server::register_route) + .ruma_route(client_server::get_login_types_route) + .ruma_route(client_server::login_route) + .ruma_route(client_server::whoami_route) + .ruma_route(client_server::logout_route) + .ruma_route(client_server::logout_all_route) + .ruma_route(client_server::change_password_route) + .ruma_route(client_server::deactivate_route) + .ruma_route(client_server::third_party_route) + 
.ruma_route(client_server::request_3pid_management_token_via_email_route) + .ruma_route(client_server::request_3pid_management_token_via_msisdn_route) + .ruma_route(client_server::get_capabilities_route) + .ruma_route(client_server::get_pushrules_all_route) + .ruma_route(client_server::set_pushrule_route) + .ruma_route(client_server::get_pushrule_route) + .ruma_route(client_server::set_pushrule_enabled_route) + .ruma_route(client_server::get_pushrule_enabled_route) + .ruma_route(client_server::get_pushrule_actions_route) + .ruma_route(client_server::set_pushrule_actions_route) + .ruma_route(client_server::delete_pushrule_route) + .ruma_route(client_server::get_room_event_route) + .ruma_route(client_server::get_room_aliases_route) + .ruma_route(client_server::get_filter_route) + .ruma_route(client_server::create_filter_route) + .ruma_route(client_server::set_global_account_data_route) + .ruma_route(client_server::set_room_account_data_route) + .ruma_route(client_server::get_global_account_data_route) + .ruma_route(client_server::get_room_account_data_route) + .ruma_route(client_server::set_displayname_route) + .ruma_route(client_server::get_displayname_route) + .ruma_route(client_server::set_avatar_url_route) + .ruma_route(client_server::get_avatar_url_route) + .ruma_route(client_server::get_profile_route) + .ruma_route(client_server::set_presence_route) + .ruma_route(client_server::get_presence_route) + .ruma_route(client_server::upload_keys_route) + .ruma_route(client_server::get_keys_route) + .ruma_route(client_server::claim_keys_route) + .ruma_route(client_server::create_backup_version_route) + .ruma_route(client_server::update_backup_version_route) + .ruma_route(client_server::delete_backup_version_route) + .ruma_route(client_server::get_latest_backup_info_route) + .ruma_route(client_server::get_backup_info_route) + .ruma_route(client_server::add_backup_keys_route) + .ruma_route(client_server::add_backup_keys_for_room_route) + 
.ruma_route(client_server::add_backup_keys_for_session_route) + .ruma_route(client_server::delete_backup_keys_for_room_route) + .ruma_route(client_server::delete_backup_keys_for_session_route) + .ruma_route(client_server::delete_backup_keys_route) + .ruma_route(client_server::get_backup_keys_for_room_route) + .ruma_route(client_server::get_backup_keys_for_session_route) + .ruma_route(client_server::get_backup_keys_route) + .ruma_route(client_server::set_read_marker_route) + .ruma_route(client_server::create_receipt_route) + .ruma_route(client_server::create_typing_event_route) + .ruma_route(client_server::create_room_route) + .ruma_route(client_server::redact_event_route) + .ruma_route(client_server::report_event_route) + .ruma_route(client_server::create_alias_route) + .ruma_route(client_server::delete_alias_route) + .ruma_route(client_server::get_alias_route) + .ruma_route(client_server::join_room_by_id_route) + .ruma_route(client_server::join_room_by_id_or_alias_route) + .ruma_route(client_server::joined_members_route) + .ruma_route(client_server::leave_room_route) + .ruma_route(client_server::forget_room_route) + .ruma_route(client_server::joined_rooms_route) + .ruma_route(client_server::kick_user_route) + .ruma_route(client_server::ban_user_route) + .ruma_route(client_server::unban_user_route) + .ruma_route(client_server::invite_user_route) + .ruma_route(client_server::set_room_visibility_route) + .ruma_route(client_server::get_room_visibility_route) + .ruma_route(client_server::get_public_rooms_route) + .ruma_route(client_server::get_public_rooms_filtered_route) + .ruma_route(client_server::search_users_route) + .ruma_route(client_server::get_member_events_route) + .ruma_route(client_server::get_protocols_route) + .ruma_route(client_server::send_message_event_route) + .ruma_route(client_server::send_state_event_for_key_route) + .ruma_route(client_server::get_state_events_route) + .ruma_route(client_server::get_state_events_for_key_route) + // Ruma doesn't 
have support for multiple paths for a single endpoint yet, and these routes + // share one Ruma request / response type pair with {get,send}_state_event_for_key_route + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + // These two endpoints allow trailing slashes + .route( + "/_matrix/client/r0/rooms/:room_id/state/:event_type/", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .route( + "/_matrix/client/v3/rooms/:room_id/state/:event_type/", + get(client_server::get_state_events_for_empty_key_route) + .put(client_server::send_state_event_for_empty_key_route), + ) + .ruma_route(client_server::sync_events_route) + .ruma_route(client_server::sync_events_v4_route) + .ruma_route(client_server::get_context_route) + .ruma_route(client_server::get_message_events_route) + .ruma_route(client_server::search_events_route) + .ruma_route(client_server::turn_server_route) + .ruma_route(client_server::send_event_to_device_route) + .ruma_route(client_server::get_media_config_route) + .ruma_route(client_server::get_media_preview_route) + .ruma_route(client_server::create_content_route) + .ruma_route(client_server::get_content_route) + .ruma_route(client_server::get_content_as_filename_route) + .ruma_route(client_server::get_content_thumbnail_route) + .ruma_route(client_server::get_devices_route) + .ruma_route(client_server::get_device_route) + .ruma_route(client_server::update_device_route) + .ruma_route(client_server::delete_device_route) + .ruma_route(client_server::delete_devices_route) + .ruma_route(client_server::get_tags_route) + .ruma_route(client_server::update_tag_route) 
+ .ruma_route(client_server::delete_tag_route) + .ruma_route(client_server::upload_signing_keys_route) + .ruma_route(client_server::upload_signatures_route) + .ruma_route(client_server::get_key_changes_route) + .ruma_route(client_server::get_pushers_route) + .ruma_route(client_server::set_pushers_route) + // .ruma_route(client_server::third_party_route) + .ruma_route(client_server::upgrade_room_route) + .ruma_route(client_server::get_threads_route) + .ruma_route(client_server::get_relating_events_with_rel_type_and_event_type_route) + .ruma_route(client_server::get_relating_events_with_rel_type_route) + .ruma_route(client_server::get_relating_events_route) + .ruma_route(client_server::get_hierarchy_route) + .ruma_route(server_server::get_server_version_route) + .route("/_matrix/key/v2/server", get(server_server::get_server_keys_route)) + .route( + "/_matrix/key/v2/server/:key_id", + get(server_server::get_server_keys_deprecated_route), + ) + .ruma_route(server_server::get_public_rooms_route) + .ruma_route(server_server::get_public_rooms_filtered_route) + .ruma_route(server_server::send_transaction_message_route) + .ruma_route(server_server::get_event_route) + .ruma_route(server_server::get_backfill_route) + .ruma_route(server_server::get_missing_events_route) + .ruma_route(server_server::get_event_authorization_route) + .ruma_route(server_server::get_room_state_route) + .ruma_route(server_server::get_room_state_ids_route) + .ruma_route(server_server::create_join_event_template_route) + .ruma_route(server_server::create_join_event_v1_route) + .ruma_route(server_server::create_join_event_v2_route) + .ruma_route(server_server::create_invite_route) + .ruma_route(server_server::get_devices_route) + .ruma_route(server_server::get_room_information_route) + .ruma_route(server_server::get_profile_information_route) + .ruma_route(server_server::get_keys_route) + .ruma_route(server_server::claim_keys_route) + .route("/_matrix/client/r0/rooms/:room_id/initialSync", 
get(initial_sync)) + .route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync)) + .route("/client/server.json", get(client_server::syncv3_client_server_json)) + .route("/.well-known/matrix/client", get(client_server::well_known_client_route)) + .route("/.well-known/matrix/server", get(server_server::well_known_server_route)) + .route("/", get(it_works)) + .fallback(not_found) } async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; + let ctrl_c = async { + signal::ctrl_c().await.expect("failed to install Ctrl+C handler"); + }; - #[cfg(unix)] - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; + #[cfg(unix)] + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; - #[cfg(not(unix))] - let terminate = std::future::pending::<()>(); + #[cfg(not(unix))] + let terminate = std::future::pending::<()>(); - let sig: &str; + let sig: &str; - tokio::select! { - _ = ctrl_c => { sig = "Ctrl+C"; }, - _ = terminate => { sig = "SIGTERM"; }, - } + tokio::select! 
{ + _ = ctrl_c => { sig = "Ctrl+C"; }, + _ = terminate => { sig = "SIGTERM"; }, + } - warn!("Received {}, shutting down...", sig); - let shutdown_time_elapsed = tokio::time::Instant::now(); - handle.graceful_shutdown(Some(Duration::from_secs(180))); + warn!("Received {}, shutting down...", sig); + let shutdown_time_elapsed = tokio::time::Instant::now(); + handle.graceful_shutdown(Some(Duration::from_secs(180))); - services().globals.shutdown(); + services().globals.shutdown(); - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]); - tx.send(()).expect("failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your system may not be in an okay/ideal state.)"); + tx.send(()).expect( + "failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \ + system may not be in an okay/ideal state.)", + ); - if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") { - warn!("Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 seconds). Remaining connections: {}", handle.connection_count()); + if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") { + warn!( + "Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 \ + seconds). 
Remaining connections: {}", + handle.connection_count() + ); - #[cfg(feature = "systemd")] - let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]); - } + #[cfg(feature = "systemd")] + let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]); + } - warn!( - "Time took to shutdown: {:?} seconds", - shutdown_time_elapsed.elapsed() - ); + warn!("Time took to shutdown: {:?} seconds", shutdown_time_elapsed.elapsed()); - Ok(()) + Ok(()) } async fn not_found(uri: Uri) -> impl IntoResponse { - warn!("Not found: {uri}"); - Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") + warn!("Not found: {uri}"); + Error::BadRequest(ErrorKind::Unrecognized, "Unrecognized request") } async fn initial_sync(_uri: Uri) -> impl IntoResponse { - Error::BadRequest( - ErrorKind::GuestAccessForbidden, - "Guest access not implemented", - ) + Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented") } -async fn it_works() -> &'static str { - "hewwo from conduwuit woof!" -} +async fn it_works() -> &'static str { "hewwo from conduwuit woof!" } trait RouterExt { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static; + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static; } impl RouterExt for Router { - fn ruma_route(self, handler: H) -> Self - where - H: RumaHandler, - T: 'static, - { - handler.add_to_router(self) - } + fn ruma_route(self, handler: H) -> Self + where + H: RumaHandler, + T: 'static, + { + handler.add_to_router(self) + } } pub trait RumaHandler { - // Can't transform to a handler without boxing or relying on the nightly-only - // impl-trait-in-traits feature. Moving a small amount of extra logic into the trait - // allows bypassing both. - fn add_to_router(self, router: Router) -> Router; + // Can't transform to a handler without boxing or relying on the nightly-only + // impl-trait-in-traits feature. 
Moving a small amount of extra logic into the + // trait allows bypassing both. + fn add_to_router(self, router: Router) -> Router; } macro_rules! impl_ruma_handler { @@ -820,33 +806,33 @@ impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7); impl_ruma_handler!(T1, T2, T3, T4, T5, T6, T7, T8); fn method_to_filter(method: Method) -> MethodFilter { - match method { - Method::DELETE => MethodFilter::DELETE, - Method::GET => MethodFilter::GET, - Method::HEAD => MethodFilter::HEAD, - Method::OPTIONS => MethodFilter::OPTIONS, - Method::PATCH => MethodFilter::PATCH, - Method::POST => MethodFilter::POST, - Method::PUT => MethodFilter::PUT, - Method::TRACE => MethodFilter::TRACE, - m => panic!("Unsupported HTTP method: {m:?}"), - } + match method { + Method::DELETE => MethodFilter::DELETE, + Method::GET => MethodFilter::GET, + Method::HEAD => MethodFilter::HEAD, + Method::OPTIONS => MethodFilter::OPTIONS, + Method::PATCH => MethodFilter::PATCH, + Method::POST => MethodFilter::POST, + Method::PUT => MethodFilter::PUT, + Method::TRACE => MethodFilter::TRACE, + m => panic!("Unsupported HTTP method: {m:?}"), + } } #[cfg(unix)] #[tracing::instrument(err)] fn maximize_fd_limit() -> Result<(), nix::errno::Errno> { - use nix::sys::resource::{getrlimit, setrlimit, Resource}; + use nix::sys::resource::{getrlimit, setrlimit, Resource}; - let res = Resource::RLIMIT_NOFILE; + let res = Resource::RLIMIT_NOFILE; - let (soft_limit, hard_limit) = getrlimit(res)?; + let (soft_limit, hard_limit) = getrlimit(res)?; - debug!("Current nofile soft limit: {soft_limit}"); + debug!("Current nofile soft limit: {soft_limit}"); - setrlimit(res, hard_limit, hard_limit)?; + setrlimit(res, hard_limit, hard_limit)?; - debug!("Increased nofile soft limit to {hard_limit}"); + debug!("Increased nofile soft limit to {hard_limit}"); - Ok(()) + Ok(()) } diff --git a/src/service/account_data/data.rs b/src/service/account_data/data.rs index c7c92981..492c500c 100644 --- a/src/service/account_data/data.rs +++ 
b/src/service/account_data/data.rs @@ -1,35 +1,28 @@ use std::collections::HashMap; -use crate::Result; use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + RoomId, UserId, }; +use crate::Result; + pub trait Data: Send + Sync { - /// Places one event in the account data of the user and removes the previous entry. - fn update( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()>; + /// Places one event in the account data of the user and removes the + /// previous entry. + fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()>; - /// Searches the account data for a specific kind. - fn get( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - kind: RoomAccountDataEventType, - ) -> Result>>; + /// Searches the account data for a specific kind. + fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType, + ) -> Result>>; - /// Returns all changes to the account data that happened after `since`. - fn changes_since( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - since: u64, - ) -> Result>>; + /// Returns all changes to the account data that happened after `since`. 
+ fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, + ) -> Result>>; } diff --git a/src/service/account_data/mod.rs b/src/service/account_data/mod.rs index bfdbd5f5..6acfbef4 100644 --- a/src/service/account_data/mod.rs +++ b/src/service/account_data/mod.rs @@ -1,53 +1,44 @@ mod data; -pub(crate) use data::Data; - -use ruma::{ - events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, - serde::Raw, - RoomId, UserId, -}; - use std::collections::HashMap; +pub(crate) use data::Data; +use ruma::{ + events::{AnyEphemeralRoomEvent, RoomAccountDataEventType}, + serde::Raw, + RoomId, UserId, +}; + use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Places one event in the account data of the user and removes the previous entry. - #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] - pub fn update( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - event_type: RoomAccountDataEventType, - data: &serde_json::Value, - ) -> Result<()> { - self.db.update(room_id, user_id, event_type, data) - } + /// Places one event in the account data of the user and removes the + /// previous entry. + #[tracing::instrument(skip(self, room_id, user_id, event_type, data))] + pub fn update( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + data: &serde_json::Value, + ) -> Result<()> { + self.db.update(room_id, user_id, event_type, data) + } - /// Searches the account data for a specific kind. - #[tracing::instrument(skip(self, room_id, user_id, event_type))] - pub fn get( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - event_type: RoomAccountDataEventType, - ) -> Result>> { - self.db.get(room_id, user_id, event_type) - } + /// Searches the account data for a specific kind. 
+ #[tracing::instrument(skip(self, room_id, user_id, event_type))] + pub fn get( + &self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType, + ) -> Result>> { + self.db.get(room_id, user_id, event_type) + } - /// Returns all changes to the account data that happened after `since`. - #[tracing::instrument(skip(self, room_id, user_id, since))] - pub fn changes_since( - &self, - room_id: Option<&RoomId>, - user_id: &UserId, - since: u64, - ) -> Result>> { - self.db.changes_since(room_id, user_id, since) - } + /// Returns all changes to the account data that happened after `since`. + #[tracing::instrument(skip(self, room_id, user_id, since))] + pub fn changes_since( + &self, room_id: Option<&RoomId>, user_id: &UserId, since: u64, + ) -> Result>> { + self.db.changes_since(room_id, user_id, since) + } } diff --git a/src/service/admin/mod.rs b/src/service/admin/mod.rs index 1cb08bb2..10641c34 100644 --- a/src/service/admin/mod.rs +++ b/src/service/admin/mod.rs @@ -1,46 +1,44 @@ use std::{ - collections::BTreeMap, - sync::{Arc, RwLock}, - time::Instant, + collections::BTreeMap, + fmt::Write as _, + sync::{Arc, RwLock}, + time::Instant, }; -use std::fmt::Write as _; - use clap::{Parser, Subcommand}; use regex::Regex; use ruma::{ - api::{appservice::Registration, client::error::ErrorKind}, - events::{ - relation::InReplyTo, - room::{ - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{JoinRule, RoomJoinRulesEventContent}, - member::{MembershipState, RoomMemberEventContent}, - message::{Relation::Reply, RoomMessageEventContent}, - name::RoomNameEventContent, - power_levels::RoomPowerLevelsEventContent, - topic::RoomTopicEventContent, - }, - TimelineEventType, - }, - EventId, MxcUri, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, - RoomOrAliasId, 
RoomVersionId, ServerName, UserId, + api::{appservice::Registration, client::error::ErrorKind}, + events::{ + relation::InReplyTo, + room::{ + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{JoinRule, RoomJoinRulesEventContent}, + member::{MembershipState, RoomMemberEventContent}, + message::{Relation::Reply, RoomMessageEventContent}, + name::RoomNameEventContent, + power_levels::RoomPowerLevelsEventContent, + topic::RoomTopicEventContent, + }, + TimelineEventType, + }, + EventId, MxcUri, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, RoomVersionId, + ServerName, UserId, }; use serde_json::value::to_raw_value; use tokio::sync::{mpsc, Mutex}; use tracing::{debug, error, info, warn}; -use crate::{ - api::client_server::{get_alias_helper, leave_all_rooms, leave_room, AUTO_GEN_PASSWORD_LENGTH}, - services, - utils::{self, HtmlEscape}, - Error, PduEvent, Result, -}; - use super::pdu::PduBuilder; +use crate::{ + api::client_server::{get_alias_helper, leave_all_rooms, leave_room, AUTO_GEN_PASSWORD_LENGTH}, + services, + utils::{self, HtmlEscape}, + Error, PduEvent, Result, +}; const PAGE_SIZE: usize = 100; @@ -48,1113 +46,1136 @@ const PAGE_SIZE: usize = 100; #[derive(Parser)] #[command(name = "@conduit:server.name:", version = env!("CARGO_PKG_VERSION"))] enum AdminCommand { - #[command(subcommand)] - /// - Commands for managing appservices - Appservices(AppserviceCommand), + #[command(subcommand)] + /// - Commands for managing appservices + Appservices(AppserviceCommand), - #[command(subcommand)] - /// - Commands for managing local users - Users(UserCommand), + #[command(subcommand)] + /// - Commands for managing local users + Users(UserCommand), - #[command(subcommand)] - /// - Commands for managing rooms - Rooms(RoomCommand), + #[command(subcommand)] + 
/// - Commands for managing rooms + Rooms(RoomCommand), - #[command(subcommand)] - /// - Commands for managing federation - Federation(FederationCommand), + #[command(subcommand)] + /// - Commands for managing federation + Federation(FederationCommand), - #[command(subcommand)] - /// - Commands for managing the server - Server(ServerCommand), + #[command(subcommand)] + /// - Commands for managing the server + Server(ServerCommand), - #[command(subcommand)] - /// - Commands for managing media - Media(MediaCommand), + #[command(subcommand)] + /// - Commands for managing media + Media(MediaCommand), - #[command(subcommand)] - // TODO: should i split out debug commands to a separate thing? the - // debug commands seem like they could fit in the other categories fine - // this is more like a "miscellaneous" category than a debug one - /// - Commands for debugging things - Debug(DebugCommand), + #[command(subcommand)] + // TODO: should i split out debug commands to a separate thing? the + // debug commands seem like they could fit in the other categories fine + // this is more like a "miscellaneous" category than a debug one + /// - Commands for debugging things + Debug(DebugCommand), } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum MediaCommand { - /// - Deletes a single media file from our database and on the filesystem via a single MXC URL - Delete { - /// The MXC URL to delete - #[arg(long)] - mxc: Option>, + /// - Deletes a single media file from our database and on the filesystem + /// via a single MXC URL + Delete { + /// The MXC URL to delete + #[arg(long)] + mxc: Option>, - /// - The message event ID which contains the media and thumbnail MXC URLs - #[arg(long)] - event_id: Option>, - }, + /// - The message event ID which contains the media and thumbnail MXC + /// URLs + #[arg(long)] + event_id: Option>, + }, - /// - Deletes a codeblock list of MXC URLs from our database and on the filesystem - DeleteList, + /// - Deletes a codeblock list of MXC 
URLs from our database and on the + /// filesystem + DeleteList, - /// - Deletes all remote media in the last X amount of time using filesystem metadata first created at date. - DeletePastRemoteMedia { - /// - The duration (at or after), e.g. "5m" to delete all media in the past 5 minutes - duration: String, - }, + /// - Deletes all remote media in the last X amount of time using filesystem + /// metadata first created at date. + DeletePastRemoteMedia { + /// - The duration (at or after), e.g. "5m" to delete all media in the + /// past 5 minutes + duration: String, + }, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum AppserviceCommand { - /// - Register an appservice using its registration YAML - /// - /// This command needs a YAML generated by an appservice (such as a bridge), - /// which must be provided in a Markdown code block below the command. - /// - /// Registering a new bridge using the ID of an existing bridge will replace - /// the old one. - Register, + /// - Register an appservice using its registration YAML + /// + /// This command needs a YAML generated by an appservice (such as a bridge), + /// which must be provided in a Markdown code block below the command. + /// + /// Registering a new bridge using the ID of an existing bridge will replace + /// the old one. + Register, - /// - Unregister an appservice using its ID - /// - /// You can find the ID using the `list-appservices` command. - Unregister { - /// The appservice to unregister - appservice_identifier: String, - }, + /// - Unregister an appservice using its ID + /// + /// You can find the ID using the `list-appservices` command. + Unregister { + /// The appservice to unregister + appservice_identifier: String, + }, - /// - Show an appservice's config using its ID - /// - /// You can find the ID using the `list-appservices` command. 
- Show { - /// The appservice to show - appservice_identifier: String, - }, + /// - Show an appservice's config using its ID + /// + /// You can find the ID using the `list-appservices` command. + Show { + /// The appservice to show + appservice_identifier: String, + }, - /// - List all the currently registered appservices - List, + /// - List all the currently registered appservices + List, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum UserCommand { - /// - Create a new user - Create { - /// Username of the new user - username: String, - /// Password of the new user, if unspecified one is generated - password: Option, - }, + /// - Create a new user + Create { + /// Username of the new user + username: String, + /// Password of the new user, if unspecified one is generated + password: Option, + }, - /// - Reset user password - ResetPassword { - /// Username of the user for whom the password should be reset - username: String, - }, + /// - Reset user password + ResetPassword { + /// Username of the user for whom the password should be reset + username: String, + }, - /// - Deactivate a user - /// - /// User will not be removed from all rooms by default. - /// Use --leave-rooms to force the user to leave all rooms - Deactivate { - #[arg(short, long)] - leave_rooms: bool, - user_id: Box, - }, + /// - Deactivate a user + /// + /// User will not be removed from all rooms by default. + /// Use --leave-rooms to force the user to leave all rooms + Deactivate { + #[arg(short, long)] + leave_rooms: bool, + user_id: Box, + }, - /// - Deactivate a list of users - /// - /// Recommended to use in conjunction with list-local-users. - /// - /// Users will not be removed from joined rooms by default. - /// Can be overridden with --leave-rooms flag. - /// Removing a mass amount of users from a room may cause a significant amount of leave events. - /// The time to leave rooms may depend significantly on joined rooms and servers. 
- /// - /// This command needs a newline separated list of users provided in a - /// Markdown code block below the command. - DeactivateAll { - #[arg(short, long)] - /// Remove users from their joined rooms - leave_rooms: bool, - #[arg(short, long)] - /// Also deactivate admin accounts - force: bool, - }, + /// - Deactivate a list of users + /// + /// Recommended to use in conjunction with list-local-users. + /// + /// Users will not be removed from joined rooms by default. + /// Can be overridden with --leave-rooms flag. + /// Removing a mass amount of users from a room may cause a significant + /// amount of leave events. The time to leave rooms may depend significantly + /// on joined rooms and servers. + /// + /// This command needs a newline separated list of users provided in a + /// Markdown code block below the command. + DeactivateAll { + #[arg(short, long)] + /// Remove users from their joined rooms + leave_rooms: bool, + #[arg(short, long)] + /// Also deactivate admin accounts + force: bool, + }, - /// - List local users in the database - List, + /// - List local users in the database + List, - /// - Lists all the rooms (local and remote) that the specified user is joined in - ListJoinedRooms { user_id: Box }, + /// - Lists all the rooms (local and remote) that the specified user is + /// joined in + ListJoinedRooms { + user_id: Box, + }, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum RoomCommand { - /// - List all rooms the server knows about - List { page: Option }, + /// - List all rooms the server knows about + List { + page: Option, + }, - #[command(subcommand)] - /// - Manage moderation of remote or local rooms - Moderation(RoomModeration), + #[command(subcommand)] + /// - Manage moderation of remote or local rooms + Moderation(RoomModeration), - #[command(subcommand)] - /// - Manage rooms' aliases - Alias(RoomAliasCommand), + #[command(subcommand)] + /// - Manage rooms' aliases + Alias(RoomAliasCommand), - #[command(subcommand)] - 
/// - Manage the room directory - Directory(RoomDirectoryCommand), + #[command(subcommand)] + /// - Manage the room directory + Directory(RoomDirectoryCommand), } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum RoomModeration { - /// - Bans a room from local users joining and evicts all our local users from the room. Also blocks any invites (local and remote) for the banned room. - /// - /// Server admins (users in the conduwuit admin room) will not be evicted and server admins can still join the room. - /// To evict admins too, use --force (also ignores errors) - /// To disable incoming federation of the room, use --disable-federation - BanRoom { - #[arg(short, long)] - /// Evicts admins out of the room and ignores any potential errors when making our local users leave the room - force: bool, + /// - Bans a room from local users joining and evicts all our local users + /// from the room. Also blocks any invites (local and remote) for the + /// banned room. + /// + /// Server admins (users in the conduwuit admin room) will not be evicted + /// and server admins can still join the room. 
To evict admins too, use + /// --force (also ignores errors) To disable incoming federation of the + /// room, use --disable-federation + BanRoom { + #[arg(short, long)] + /// Evicts admins out of the room and ignores any potential errors when + /// making our local users leave the room + force: bool, - #[arg(long)] - /// Disables incoming federation of the room after banning and evicting users - disable_federation: bool, + #[arg(long)] + /// Disables incoming federation of the room after banning and evicting + /// users + disable_federation: bool, - /// The room in the format of `!roomid:example.com` or a room alias in the format of `#roomalias:example.com` - room: Box, - }, + /// The room in the format of `!roomid:example.com` or a room alias in + /// the format of `#roomalias:example.com` + room: Box, + }, - /// - Bans a list of rooms from a newline delimited codeblock similar to `user deactivate-all` - BanListOfRooms { - #[arg(short, long)] - /// Evicts admins out of the room and ignores any potential errors when making our local users leave the room - force: bool, + /// - Bans a list of rooms from a newline delimited codeblock similar to + /// `user deactivate-all` + BanListOfRooms { + #[arg(short, long)] + /// Evicts admins out of the room and ignores any potential errors when + /// making our local users leave the room + force: bool, - #[arg(long)] - /// Disables incoming federation of the room after banning and evicting users - disable_federation: bool, - }, + #[arg(long)] + /// Disables incoming federation of the room after banning and evicting + /// users + disable_federation: bool, + }, - /// - Unbans a room to allow local users to join again - /// - /// To re-enable incoming federation of the room, use --enable-federation - UnbanRoom { - #[arg(long)] - /// Enables incoming federation of the room after unbanning - enable_federation: bool, + /// - Unbans a room to allow local users to join again + /// + /// To re-enable incoming federation of the room, 
use --enable-federation + UnbanRoom { + #[arg(long)] + /// Enables incoming federation of the room after unbanning + enable_federation: bool, - /// The room in the format of `!roomid:example.com` or a room alias in the format of `#roomalias:example.com` - room: Box, - }, + /// The room in the format of `!roomid:example.com` or a room alias in + /// the format of `#roomalias:example.com` + room: Box, + }, - /// - List of all rooms we have banned - ListBannedRooms, + /// - List of all rooms we have banned + ListBannedRooms, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum RoomAliasCommand { - /// - Make an alias point to a room. - Set { - #[arg(short, long)] - /// Set the alias even if a room is already using it - force: bool, + /// - Make an alias point to a room. + Set { + #[arg(short, long)] + /// Set the alias even if a room is already using it + force: bool, - /// The room id to set the alias on - room_id: Box, + /// The room id to set the alias on + room_id: Box, - /// The alias localpart to use (`alias`, not `#alias:servername.tld`) - room_alias_localpart: String, - }, + /// The alias localpart to use (`alias`, not `#alias:servername.tld`) + room_alias_localpart: String, + }, - /// - Remove an alias - Remove { - /// The alias localpart to remove (`alias`, not `#alias:servername.tld`) - room_alias_localpart: String, - }, + /// - Remove an alias + Remove { + /// The alias localpart to remove (`alias`, not `#alias:servername.tld`) + room_alias_localpart: String, + }, - /// - Show which room is using an alias - Which { - /// The alias localpart to look up (`alias`, not `#alias:servername.tld`) - room_alias_localpart: String, - }, + /// - Show which room is using an alias + Which { + /// The alias localpart to look up (`alias`, not + /// `#alias:servername.tld`) + room_alias_localpart: String, + }, - /// - List aliases currently being used - List { - /// If set, only list the aliases for this room - room_id: Option>, - }, + /// - List aliases currently 
being used + List { + /// If set, only list the aliases for this room + room_id: Option>, + }, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum RoomDirectoryCommand { - /// - Publish a room to the room directory - Publish { - /// The room id of the room to publish - room_id: Box, - }, + /// - Publish a room to the room directory + Publish { + /// The room id of the room to publish + room_id: Box, + }, - /// - Unpublish a room to the room directory - Unpublish { - /// The room id of the room to unpublish - room_id: Box, - }, + /// - Unpublish a room to the room directory + Unpublish { + /// The room id of the room to unpublish + room_id: Box, + }, - /// - List rooms that are published - List { page: Option }, + /// - List rooms that are published + List { + page: Option, + }, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum FederationCommand { - /// - List all rooms we are currently handling an incoming pdu from - IncomingFederation, + /// - List all rooms we are currently handling an incoming pdu from + IncomingFederation, - /// - Disables incoming federation handling for a room. - DisableRoom { room_id: Box }, + /// - Disables incoming federation handling for a room. + DisableRoom { + room_id: Box, + }, - /// - Enables incoming federation handling for a room again. - EnableRoom { room_id: Box }, + /// - Enables incoming federation handling for a room again. + EnableRoom { + room_id: Box, + }, - /// - Verify json signatures - /// - /// This command needs a JSON blob provided in a Markdown code block below - /// the command. - SignJson, + /// - Verify json signatures + /// + /// This command needs a JSON blob provided in a Markdown code block below + /// the command. + SignJson, - /// - Verify json signatures - /// - /// This command needs a JSON blob provided in a Markdown code block below - /// the command. 
- VerifyJson, + /// - Verify json signatures + /// + /// This command needs a JSON blob provided in a Markdown code block below + /// the command. + VerifyJson, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum DebugCommand { - /// - Get the auth_chain of a PDU - GetAuthChain { - /// An event ID (the $ character followed by the base64 reference hash) - event_id: Box, - }, + /// - Get the auth_chain of a PDU + GetAuthChain { + /// An event ID (the $ character followed by the base64 reference hash) + event_id: Box, + }, - /// - Parse and print a PDU from a JSON - /// - /// The PDU event is only checked for validity and is not added to the - /// database. - /// - /// This command needs a JSON blob provided in a Markdown code block below - /// the command. - ParsePdu, + /// - Parse and print a PDU from a JSON + /// + /// The PDU event is only checked for validity and is not added to the + /// database. + /// + /// This command needs a JSON blob provided in a Markdown code block below + /// the command. 
+ ParsePdu, - /// - Retrieve and print a PDU by ID from the Conduit database - GetPdu { - /// An event ID (a $ followed by the base64 reference hash) - event_id: Box, - }, + /// - Retrieve and print a PDU by ID from the Conduit database + GetPdu { + /// An event ID (a $ followed by the base64 reference hash) + event_id: Box, + }, - /// - Forces device lists for all the local users to be updated - ForceDeviceListUpdates, + /// - Forces device lists for all the local users to be updated + ForceDeviceListUpdates, } #[cfg_attr(test, derive(Debug))] #[derive(Subcommand)] enum ServerCommand { - /// - Show configuration values - ShowConfig, + /// - Show configuration values + ShowConfig, - /// - Print database memory usage statistics - MemoryUsage, + /// - Print database memory usage statistics + MemoryUsage, - /// - Clears all of Conduit's database caches with index smaller than the amount - ClearDatabaseCaches { amount: u32 }, + /// - Clears all of Conduit's database caches with index smaller than the + /// amount + ClearDatabaseCaches { + amount: u32, + }, - /// - Clears all of Conduit's service caches with index smaller than the amount - ClearServiceCaches { amount: u32 }, + /// - Clears all of Conduit's service caches with index smaller than the + /// amount + ClearServiceCaches { + amount: u32, + }, } #[derive(Debug)] pub enum AdminRoomEvent { - ProcessMessage(String, Arc), - SendMessage(RoomMessageEventContent), + ProcessMessage(String, Arc), + SendMessage(RoomMessageEventContent), } pub struct Service { - pub sender: mpsc::UnboundedSender, - receiver: Mutex>, + pub sender: mpsc::UnboundedSender, + receiver: Mutex>, } impl Service { - pub fn build() -> Arc { - let (sender, receiver) = mpsc::unbounded_channel(); - Arc::new(Self { - sender, - receiver: Mutex::new(receiver), - }) - } - - pub fn start_handler(self: &Arc) { - let self2 = Arc::clone(self); - tokio::spawn(async move { - self2.handler().await; - }); - } - - async fn handler(&self) { - let mut receiver = 
self.receiver.lock().await; - // TODO: Use futures when we have long admin commands - //let mut futures = FuturesUnordered::new(); - - let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) - .expect("@conduit:server_name is valid"); - - let conduit_room = services() - .rooms - .alias - .resolve_local_alias( - format!("#admins:{}", services().globals.server_name()) - .as_str() - .try_into() - .expect("#admins:server_name is a valid room alias"), - ) - .expect("Database data for admin room alias must be valid") - .expect("Admin room must exist"); - - loop { - tokio::select! { - Some(event) = receiver.recv() => { - let (mut message_content, reply) = match event { - AdminRoomEvent::SendMessage(content) => (content, None), - AdminRoomEvent::ProcessMessage(room_message, reply_id) => { - (self.process_admin_message(room_message).await, Some(reply_id)) - } - }; - - let mutex_state = Arc::clone( - services().globals - .roomid_mutex_state - .write() - .unwrap() - .entry(conduit_room.clone()) - .or_default(), - ); - - let state_lock = mutex_state.lock().await; - - if let Some(reply) = reply { - message_content.relates_to = Some(Reply { in_reply_to: InReplyTo { event_id: reply.into() } }); - } - - services().rooms.timeline.build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMessage, - content: to_raw_value(&message_content) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: None, - redacts: None, - }, - &conduit_user, - &conduit_room, - &state_lock) - .await - .unwrap(); - - - drop(state_lock); - } - } - } - } - - pub fn process_message(&self, room_message: String, event_id: Arc) { - self.sender - .send(AdminRoomEvent::ProcessMessage(room_message, event_id)) - .unwrap(); - } - - pub fn send_message(&self, message_content: RoomMessageEventContent) { - self.sender - .send(AdminRoomEvent::SendMessage(message_content)) - .unwrap(); - } - - // Parse and process a message from the admin room - 
async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { - let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); - let command_line = lines.next().expect("each string has at least one line"); - let body: Vec<_> = lines.collect(); - - let admin_command = match self.parse_admin_command(command_line) { - Ok(command) => command, - Err(error) => { - let server_name = services().globals.server_name(); - let message = error.replace("server.name", server_name.as_str()); - let html_message = self.usage_to_html(&message, server_name); - - return RoomMessageEventContent::text_html(message, html_message); - } - }; - - match self.process_admin_command(admin_command, body).await { - Ok(reply_message) => reply_message, - Err(error) => { - let markdown_message = format!( - "Encountered an error while handling the command:\n\ - ```\n{error}\n```", - ); - let html_message = format!( - "Encountered an error while handling the command:\n\ -
\n{error}\n
", - ); - - RoomMessageEventContent::text_html(markdown_message, html_message) - } - } - } - - // Parse chat messages from the admin room into an AdminCommand object - fn parse_admin_command(&self, command_line: &str) -> std::result::Result { - // Note: argv[0] is `@conduit:servername:`, which is treated as the main command - let mut argv: Vec<_> = command_line.split_whitespace().collect(); - - // Replace `help command` with `command --help` - // Clap has a help subcommand, but it omits the long help description. - if argv.len() > 1 && argv[1] == "help" { - argv.remove(1); - argv.push("--help"); - } - - // Backwards compatibility with `register_appservice`-style commands - let command_with_dashes; - if argv.len() > 1 && argv[1].contains('_') { - command_with_dashes = argv[1].replace('_', "-"); - argv[1] = &command_with_dashes; - } - - AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) - } - - async fn process_admin_command( - &self, - command: AdminCommand, - body: Vec<&str>, - ) -> Result { - let reply_message_content = match command { - AdminCommand::Appservices(command) => match command { - AppserviceCommand::Register => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let appservice_config = body[1..body.len() - 1].join("\n"); - let parsed_config = - serde_yaml::from_str::(&appservice_config); - match parsed_config { - Ok(yaml) => match services().appservice.register_appservice(yaml) { - Ok(id) => RoomMessageEventContent::text_plain(format!( - "Appservice registered with ID: {id}." - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to register appservice: {e}" - )), - }, - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse appservice config: {e}" - )), - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. 
Add --help for details.", - ) - } - } - AppserviceCommand::Unregister { - appservice_identifier, - } => match services() - .appservice - .unregister_appservice(&appservice_identifier) - { - Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Failed to unregister appservice: {e}" - )), - }, - AppserviceCommand::Show { - appservice_identifier, - } => { - match services() - .appservice - .get_registration(&appservice_identifier) - { - Ok(Some(config)) => { - let config_str = serde_yaml::to_string(&config) - .expect("config should've been validated on register"); - let output = format!( - "Config for {}:\n\n```yaml\n{}\n```", - appservice_identifier, config_str, - ); - let output_html = format!( - "Config for {}:\n\n
{}
", - escape_html(&appservice_identifier), - escape_html(&config_str), - ); - RoomMessageEventContent::text_html(output, output_html) - } - Ok(None) => { - RoomMessageEventContent::text_plain("Appservice does not exist.") - } - Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."), - } - } - AppserviceCommand::List => { - if let Ok(appservices) = services() - .appservice - .iter_ids() - .map(std::iter::Iterator::collect::>) - { - let count = appservices.len(); - let output = format!( - "Appservices ({}): {}", - count, - appservices - .into_iter() - .filter_map(std::result::Result::ok) - .collect::>() - .join(", ") - ); - RoomMessageEventContent::text_plain(output) - } else { - RoomMessageEventContent::text_plain("Failed to get appservices.") - } - } - }, - AdminCommand::Media(command) => match command { - MediaCommand::Delete { mxc, event_id } => { - if event_id.is_some() && mxc.is_some() { - return Ok(RoomMessageEventContent::text_plain( - "Please specify either an MXC or an event ID, not both.", - )); - } - - if let Some(mxc) = mxc { - if !mxc.to_string().starts_with("mxc://") { - return Ok(RoomMessageEventContent::text_plain( - "MXC provided is not valid.", - )); - } - - debug!("Got MXC URL: {}", mxc); - services().media.delete(mxc.to_string()).await?; - - return Ok(RoomMessageEventContent::text_plain( - "Deleted the MXC from our database and on our filesystem.", - )); - } else if let Some(event_id) = event_id { - debug!("Got event ID to delete media from: {}", event_id); - - let mut mxc_urls: Vec = vec![]; - let mut mxc_deletion_count = 0; - - // parsing the PDU for any MXC URLs begins here - if let Some(event_json) = - services().rooms.timeline.get_pdu_json(&event_id)? - { - if let Some(content_key) = event_json.get("content") { - debug!("Event ID has \"content\"."); - let content_obj = content_key.as_object(); - - if let Some(content) = content_obj { - // 1. 
attempts to parse the "url" key - debug!("Attempting to go into \"url\" key for main media file"); - if let Some(url) = content.get("url") { - debug!("Got a URL in the event ID {event_id}: {url}"); - - if url.to_string().starts_with("\"mxc://") { - debug!("Pushing URL {} to list of MXCs to delete", url); - let final_url = url.to_string().replace('"', ""); - mxc_urls.push(final_url); - } else { - info!("Found a URL in the event ID {event_id} but did not start with mxc://, ignoring"); - } - } - - // 2. attempts to parse the "info" key - debug!("Attempting to go into \"info\" key for thumbnails"); - if let Some(info_key) = content.get("info") { - debug!("Event ID has \"info\"."); - let info_obj = info_key.as_object(); - - if let Some(info) = info_obj { - if let Some(thumbnail_url) = info.get("thumbnail_url") { - debug!("Found a thumbnail_url in info key: {thumbnail_url}"); - - if thumbnail_url.to_string().starts_with("\"mxc://") - { - debug!("Pushing thumbnail URL {} to list of MXCs to delete", thumbnail_url); - let final_thumbnail_url = - thumbnail_url.to_string().replace('"', ""); - mxc_urls.push(final_thumbnail_url); - } else { - info!("Found a thumbnail URL in the event ID {event_id} but did not start with mxc://, ignoring"); - } - } else { - info!("No \"thumbnail_url\" key in \"info\" key, assuming no thumbnails."); - } - } - } - - // 3. 
attempts to parse the "file" key - debug!("Attempting to go into \"file\" key"); - if let Some(file_key) = content.get("file") { - debug!("Event ID has \"file\"."); - let file_obj = file_key.as_object(); - - if let Some(file) = file_obj { - if let Some(url) = file.get("url") { - debug!("Found url in file key: {url}"); - - if url.to_string().starts_with("\"mxc://") { - debug!( - "Pushing URL {} to list of MXCs to delete", - url - ); - let final_url = - url.to_string().replace('"', ""); - mxc_urls.push(final_url); - } else { - info!("Found a URL in the event ID {event_id} but did not start with mxc://, ignoring"); - } - } else { - info!("No \"url\" key in \"file\" key."); - } - } - } - } else { - return Ok(RoomMessageEventContent::text_plain("Event ID does not have a \"content\" key or failed parsing the event ID JSON.")); - } - } else { - return Ok(RoomMessageEventContent::text_plain("Event ID does not have a \"content\" key, this is not a message or an event type that contains media.")); - } - } else { - return Ok(RoomMessageEventContent::text_plain( - "Event ID does not exist or is not known to us.", - )); - } - - if mxc_urls.is_empty() { - // we shouldn't get here (should have errored earlier) but just in case for whatever reason we do... - info!("Parsed event ID {event_id} but did not contain any MXC URLs."); - return Ok(RoomMessageEventContent::text_plain( - "Parsed event ID but found no MXC URLs.", - )); - } - - for mxc_url in mxc_urls { - services().media.delete(mxc_url).await?; - mxc_deletion_count += 1; - } - - return Ok(RoomMessageEventContent::text_plain(format!("Deleted {} total MXCs from our database and the filesystem from event ID {event_id}.", mxc_deletion_count))); - } else { - return Ok(RoomMessageEventContent::text_plain( - "Please specify either an MXC using --mxc or an event ID using --event-id of the message containing an image. 
See --help for details.", - )); - } - } - MediaCommand::DeleteList => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let mxc_list = body.clone().drain(1..body.len() - 1).collect::>(); - - let mut mxc_deletion_count = 0; - - for mxc in mxc_list { - debug!("Deleting MXC {} in bulk", mxc); - services().media.delete(mxc.to_owned()).await?; - mxc_deletion_count += 1; - } - - return Ok(RoomMessageEventContent::text_plain(format!("Finished bulk MXC deletion, deleted {} total MXCs from our database and the filesystem.", mxc_deletion_count))); - } else { - return Ok(RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - )); - } - } - MediaCommand::DeletePastRemoteMedia { duration } => { - let deleted_count = services() - .media - .delete_all_remote_media_at_after_time(duration) - .await?; - - return Ok(RoomMessageEventContent::text_plain(format!( - "Deleted {} total files.", - deleted_count - ))); - } - }, - AdminCommand::Users(command) => match command { - UserCommand::List => match services().users.list_local_users() { - Ok(users) => { - let mut msg: String = - format!("Found {} local user account(s):\n", users.len()); - msg += &users.join("\n"); - RoomMessageEventContent::text_plain(&msg) - } - Err(e) => RoomMessageEventContent::text_plain(e.to_string()), - }, - UserCommand::Create { username, password } => { - let password = - password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); - // Validate user id - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - services().globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {e}" - ))) - } - }; - if user_id.is_historical() { - return Ok(RoomMessageEventContent::text_plain(format!( - "Userid {user_id} is not allowed due to historical" - ))); - } - 
if services().users.exists(&user_id)? { - return Ok(RoomMessageEventContent::text_plain(format!( - "Userid {user_id} already exists" - ))); - } - // Create user - services().users.create(&user_id, Some(password.as_str()))?; - - // Default to pretty displayname - let mut displayname = user_id.localpart().to_owned(); - - // If `new_user_displayname_suffix` is set, registration will push whatever content is set to the user's display name with a space before it - if !services().globals.new_user_displayname_suffix().is_empty() { - displayname.push_str( - &(" ".to_owned() + services().globals.new_user_displayname_suffix()), - ); - } - - services() - .users - .set_displayname(&user_id, Some(displayname)) - .await?; - - // Initial account data - services().account_data.update( - None, - &user_id, - ruma::events::GlobalAccountDataEventType::PushRules - .to_string() - .into(), - &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { - content: ruma::events::push_rules::PushRulesEventContent { - global: ruma::push::Ruleset::server_default(&user_id), - }, - }) - .expect("to json value always works"), - )?; - - // we dont add a device since we're not the user, just the creator - - // Inhibit login does not work for guests - RoomMessageEventContent::text_plain(format!( - "Created user with user_id: {user_id} and password: `{password}`" - )) - } - UserCommand::Deactivate { - leave_rooms, - user_id, - } => { - let user_id = Arc::::from(user_id); - - // check if user belongs to our server - if user_id.server_name() != services().globals.server_name() { - return Ok(RoomMessageEventContent::text_plain(format!( - "User {user_id} does not belong to our server." - ))); - } - - if services().users.exists(&user_id)? { - RoomMessageEventContent::text_plain(format!( - "Making {user_id} leave all rooms before deactivation..." 
- )); - - services().users.deactivate_account(&user_id)?; - - if leave_rooms { - leave_all_rooms(&user_id).await?; - } - - RoomMessageEventContent::text_plain(format!( - "User {user_id} has been deactivated" - )) - } else { - RoomMessageEventContent::text_plain(format!( - "User {user_id} doesn't exist on this server" - )) - } - } - UserCommand::ResetPassword { username } => { - let user_id = match UserId::parse_with_server_name( - username.as_str().to_lowercase(), - services().globals.server_name(), - ) { - Ok(id) => id, - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "The supplied username is not a valid username: {e}" - ))) - } - }; - - // check if user belongs to our server - if user_id.server_name() != services().globals.server_name() { - return Ok(RoomMessageEventContent::text_plain(format!( - "User {user_id} does not belong to our server." - ))); - } - - // Check if the specified user is valid - if !services().users.exists(&user_id)? - || user_id - == UserId::parse_with_server_name( - "conduit", - services().globals.server_name(), - ) - .expect("conduit user exists") - { - return Ok(RoomMessageEventContent::text_plain( - "The specified user does not exist!", - )); - } - - let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); - - match services() - .users - .set_password(&user_id, Some(new_password.as_str())) - { - Ok(()) => RoomMessageEventContent::text_plain(format!( - "Successfully reset the password for user {user_id}: `{new_password}`" - )), - Err(e) => RoomMessageEventContent::text_plain(format!( - "Couldn't reset the password for user {user_id}: {e}" - )), - } - } - UserCommand::DeactivateAll { leave_rooms, force } => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let usernames = body.clone().drain(1..body.len() - 1).collect::>(); - - let mut user_ids: Vec<&UserId> = Vec::new(); - - for &username in &usernames { - match <&UserId>::try_from(username) { - 
Ok(user_id) => user_ids.push(user_id), - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "{username} is not a valid username: {e}" - ))) - } - } - } - - let mut deactivation_count = 0; - let mut admins = Vec::new(); - - if !force { - user_ids.retain(|&user_id| match services().users.is_admin(user_id) { - Ok(is_admin) => match is_admin { - true => { - admins.push(user_id.localpart()); - false - } - false => true, - }, - Err(_) => false, - }); - } - - for &user_id in &user_ids { - // check if user belongs to our server and skips over non-local users - if user_id.server_name() != services().globals.server_name() { - continue; - } - - if services().users.deactivate_account(user_id).is_ok() { - deactivation_count += 1; - } - } - - if leave_rooms { - for &user_id in &user_ids { - let _ = leave_all_rooms(user_id).await; - } - } - - if admins.is_empty() { - RoomMessageEventContent::text_plain(format!( - "Deactivated {deactivation_count} accounts." - )) - } else { - RoomMessageEventContent::text_plain(format!("Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate admin accounts", deactivation_count, admins.join(", "))) - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. 
Add --help for details.", - ) - } - } - UserCommand::ListJoinedRooms { user_id } => { - if user_id.server_name() != services().globals.server_name() { - return Ok(RoomMessageEventContent::text_plain( - "User does not belong to our server.", - )); - } - - let mut rooms: Vec<(OwnedRoomId, u64, String)> = vec![]; // room ID, members joined, room name - - for room_id in services().rooms.state_cache.rooms_joined(&user_id) { - let room_id = room_id?; - rooms.push(Self::get_room_info(room_id)); - } - - if rooms.is_empty() { - return Ok(RoomMessageEventContent::text_plain( - "User is not in any rooms.", - )); - } - - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let output_plain = format!( - "Rooms {user_id} Joined:\n{}", - rooms - .iter() - .map(|(id, members, name)| format!( - "{id}\tMembers: {members}\tName: {name}" - )) - .collect::>() - .join("\n") - ); - let output_html = format!( - "\n\t\t\n{}
Rooms {user_id} Joined
idmembersname
", - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!(output, "{}\t{}\t{}", escape_html(id.as_ref()), - members, - escape_html(name)).unwrap(); - output - }) - ); - RoomMessageEventContent::text_html(output_plain, output_html) - } - }, - AdminCommand::Rooms(command) => match command { - RoomCommand::Moderation(command) => match command { - RoomModeration::BanRoom { - force, - room, - disable_federation, - } => { - debug!("Got room alias or ID: {}", room); - - let admin_room_alias: Box = - format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let admin_room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias)? - .expect("Admin room must exist"); - - if room.to_string().eq(&admin_room_id) - || room.to_string().eq(&admin_room_alias) - { - return Ok(RoomMessageEventContent::text_plain( - "Not allowed to ban the admin room.", - )); - } - - let room_id = if room.is_room_id() { - let room_id = match RoomId::parse(&room) { - Ok(room_id) => room_id, - Err(e) => return Ok(RoomMessageEventContent::text_plain(format!("Failed to parse room ID {room}. Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"))), - }; - - debug!("Room specified is a room ID, banning room ID"); - - services().rooms.metadata.ban_room(&room_id, true)?; - - room_id - } else if room.is_room_alias_id() { - let room_alias = match RoomAliasId::parse(&room) { - Ok(room_alias) => room_alias, - Err(e) => return Ok(RoomMessageEventContent::text_plain(format!("Failed to parse room ID {room}. 
Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"))), - }; - - debug!("Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not using get_alias_helper to fetch room ID remotely"); - - let room_id = match services() - .rooms - .alias - .resolve_local_alias(&room_alias)? - { - Some(room_id) => room_id, - None => { - debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); - - match get_alias_helper(room_alias).await { - Ok(response) => { - debug!("Got federation response fetching room ID for room {room}: {:?}", response); - response.room_id - } - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!("Failed to resolve room alias {room} to a room ID: {e}"))); - } - } - } - }; - - services().rooms.metadata.ban_room(&room_id, true)?; - - room_id - } else { - return Ok(RoomMessageEventContent::text_plain("Room specified is not a room ID or room alias. 
Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)")); - }; - - debug!("Making all users leave the room {}", &room); - if force { - for local_user in services() - .rooms - .state_cache - .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + pub fn build() -> Arc { + let (sender, receiver) = mpsc::unbounded_channel(); + Arc::new(Self { + sender, + receiver: Mutex::new(receiver), + }) + } + + pub fn start_handler(self: &Arc) { + let self2 = Arc::clone(self); + tokio::spawn(async move { + self2.handler().await; + }); + } + + async fn handler(&self) { + let mut receiver = self.receiver.lock().await; + // TODO: Use futures when we have long admin commands + //let mut futures = FuturesUnordered::new(); + + let conduit_user = UserId::parse(format!("@conduit:{}", services().globals.server_name())) + .expect("@conduit:server_name is valid"); + + let conduit_room = services() + .rooms + .alias + .resolve_local_alias( + format!("#admins:{}", services().globals.server_name()) + .as_str() + .try_into() + .expect("#admins:server_name is a valid room alias"), + ) + .expect("Database data for admin room alias must be valid") + .expect("Admin room must exist"); + + loop { + tokio::select! 
{ + Some(event) = receiver.recv() => { + let (mut message_content, reply) = match event { + AdminRoomEvent::SendMessage(content) => (content, None), + AdminRoomEvent::ProcessMessage(room_message, reply_id) => { + (self.process_admin_message(room_message).await, Some(reply_id)) + } + }; + + let mutex_state = Arc::clone( + services().globals + .roomid_mutex_state + .write() + .unwrap() + .entry(conduit_room.clone()) + .or_default(), + ); + + let state_lock = mutex_state.lock().await; + + if let Some(reply) = reply { + message_content.relates_to = Some(Reply { in_reply_to: InReplyTo { event_id: reply.into() } }); + } + + services().rooms.timeline.build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMessage, + content: to_raw_value(&message_content) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: None, + redacts: None, + }, + &conduit_user, + &conduit_room, + &state_lock) + .await + .unwrap(); + + + drop(state_lock); + } + } + } + } + + pub fn process_message(&self, room_message: String, event_id: Arc) { + self.sender.send(AdminRoomEvent::ProcessMessage(room_message, event_id)).unwrap(); + } + + pub fn send_message(&self, message_content: RoomMessageEventContent) { + self.sender.send(AdminRoomEvent::SendMessage(message_content)).unwrap(); + } + + // Parse and process a message from the admin room + async fn process_admin_message(&self, room_message: String) -> RoomMessageEventContent { + let mut lines = room_message.lines().filter(|l| !l.trim().is_empty()); + let command_line = lines.next().expect("each string has at least one line"); + let body: Vec<_> = lines.collect(); + + let admin_command = match self.parse_admin_command(command_line) { + Ok(command) => command, + Err(error) => { + let server_name = services().globals.server_name(); + let message = error.replace("server.name", server_name.as_str()); + let html_message = self.usage_to_html(&message, server_name); + + return 
RoomMessageEventContent::text_html(message, html_message); + }, + }; + + match self.process_admin_command(admin_command, body).await { + Ok(reply_message) => reply_message, + Err(error) => { + let markdown_message = format!("Encountered an error while handling the command:\n```\n{error}\n```",); + let html_message = format!("Encountered an error while handling the command:\n
\n{error}\n
",); + + RoomMessageEventContent::text_html(markdown_message, html_message) + }, + } + } + + // Parse chat messages from the admin room into an AdminCommand object + fn parse_admin_command(&self, command_line: &str) -> std::result::Result { + // Note: argv[0] is `@conduit:servername:`, which is treated as the main command + let mut argv: Vec<_> = command_line.split_whitespace().collect(); + + // Replace `help command` with `command --help` + // Clap has a help subcommand, but it omits the long help description. + if argv.len() > 1 && argv[1] == "help" { + argv.remove(1); + argv.push("--help"); + } + + // Backwards compatibility with `register_appservice`-style commands + let command_with_dashes; + if argv.len() > 1 && argv[1].contains('_') { + command_with_dashes = argv[1].replace('_', "-"); + argv[1] = &command_with_dashes; + } + + AdminCommand::try_parse_from(argv).map_err(|error| error.to_string()) + } + + async fn process_admin_command(&self, command: AdminCommand, body: Vec<&str>) -> Result { + let reply_message_content = match command { + AdminCommand::Appservices(command) => match command { + AppserviceCommand::Register => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let appservice_config = body[1..body.len() - 1].join("\n"); + let parsed_config = serde_yaml::from_str::(&appservice_config); + match parsed_config { + Ok(yaml) => match services().appservice.register_appservice(yaml) { + Ok(id) => { + RoomMessageEventContent::text_plain(format!("Appservice registered with ID: {id}.")) + }, + Err(e) => { + RoomMessageEventContent::text_plain(format!("Failed to register appservice: {e}")) + }, + }, + Err(e) => { + RoomMessageEventContent::text_plain(format!("Could not parse appservice config: {e}")) + }, + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. 
Add --help for details.", + ) + } + }, + AppserviceCommand::Unregister { + appservice_identifier, + } => match services().appservice.unregister_appservice(&appservice_identifier) { + Ok(()) => RoomMessageEventContent::text_plain("Appservice unregistered."), + Err(e) => RoomMessageEventContent::text_plain(format!("Failed to unregister appservice: {e}")), + }, + AppserviceCommand::Show { + appservice_identifier, + } => match services().appservice.get_registration(&appservice_identifier) { + Ok(Some(config)) => { + let config_str = + serde_yaml::to_string(&config).expect("config should've been validated on register"); + let output = format!("Config for {}:\n\n```yaml\n{}\n```", appservice_identifier, config_str,); + let output_html = format!( + "Config for {}:\n\n
{}
", + escape_html(&appservice_identifier), + escape_html(&config_str), + ); + RoomMessageEventContent::text_html(output, output_html) + }, + Ok(None) => RoomMessageEventContent::text_plain("Appservice does not exist."), + Err(_) => RoomMessageEventContent::text_plain("Failed to get appservice."), + }, + AppserviceCommand::List => { + if let Ok(appservices) = + services().appservice.iter_ids().map(std::iter::Iterator::collect::>) + { + let count = appservices.len(); + let output = format!( + "Appservices ({}): {}", + count, + appservices.into_iter().filter_map(std::result::Result::ok).collect::>().join(", ") + ); + RoomMessageEventContent::text_plain(output) + } else { + RoomMessageEventContent::text_plain("Failed to get appservices.") + } + }, + }, + AdminCommand::Media(command) => { + match command { + MediaCommand::Delete { + mxc, + event_id, + } => { + if event_id.is_some() && mxc.is_some() { + return Ok(RoomMessageEventContent::text_plain( + "Please specify either an MXC or an event ID, not both.", + )); + } + + if let Some(mxc) = mxc { + if !mxc.to_string().starts_with("mxc://") { + return Ok(RoomMessageEventContent::text_plain("MXC provided is not valid.")); + } + + debug!("Got MXC URL: {}", mxc); + services().media.delete(mxc.to_string()).await?; + + return Ok(RoomMessageEventContent::text_plain( + "Deleted the MXC from our database and on our filesystem.", + )); + } else if let Some(event_id) = event_id { + debug!("Got event ID to delete media from: {}", event_id); + + let mut mxc_urls: Vec = vec![]; + let mut mxc_deletion_count = 0; + + // parsing the PDU for any MXC URLs begins here + if let Some(event_json) = services().rooms.timeline.get_pdu_json(&event_id)? { + if let Some(content_key) = event_json.get("content") { + debug!("Event ID has \"content\"."); + let content_obj = content_key.as_object(); + + if let Some(content) = content_obj { + // 1. 
attempts to parse the "url" key + debug!("Attempting to go into \"url\" key for main media file"); + if let Some(url) = content.get("url") { + debug!("Got a URL in the event ID {event_id}: {url}"); + + if url.to_string().starts_with("\"mxc://") { + debug!("Pushing URL {} to list of MXCs to delete", url); + let final_url = url.to_string().replace('"', ""); + mxc_urls.push(final_url); + } else { + info!( + "Found a URL in the event ID {event_id} but did not start with \ + mxc://, ignoring" + ); + } + } + + // 2. attempts to parse the "info" key + debug!("Attempting to go into \"info\" key for thumbnails"); + if let Some(info_key) = content.get("info") { + debug!("Event ID has \"info\"."); + let info_obj = info_key.as_object(); + + if let Some(info) = info_obj { + if let Some(thumbnail_url) = info.get("thumbnail_url") { + debug!("Found a thumbnail_url in info key: {thumbnail_url}"); + + if thumbnail_url.to_string().starts_with("\"mxc://") { + debug!( + "Pushing thumbnail URL {} to list of MXCs to delete", + thumbnail_url + ); + let final_thumbnail_url = + thumbnail_url.to_string().replace('"', ""); + mxc_urls.push(final_thumbnail_url); + } else { + info!( + "Found a thumbnail URL in the event ID {event_id} but did \ + not start with mxc://, ignoring" + ); + } + } else { + info!( + "No \"thumbnail_url\" key in \"info\" key, assuming no \ + thumbnails." + ); + } + } + } + + // 3. 
attempts to parse the "file" key + debug!("Attempting to go into \"file\" key"); + if let Some(file_key) = content.get("file") { + debug!("Event ID has \"file\"."); + let file_obj = file_key.as_object(); + + if let Some(file) = file_obj { + if let Some(url) = file.get("url") { + debug!("Found url in file key: {url}"); + + if url.to_string().starts_with("\"mxc://") { + debug!("Pushing URL {} to list of MXCs to delete", url); + let final_url = url.to_string().replace('"', ""); + mxc_urls.push(final_url); + } else { + info!( + "Found a URL in the event ID {event_id} but did not start \ + with mxc://, ignoring" + ); + } + } else { + info!("No \"url\" key in \"file\" key."); + } + } + } + } else { + return Ok(RoomMessageEventContent::text_plain( + "Event ID does not have a \"content\" key or failed parsing the event ID \ + JSON.", + )); + } + } else { + return Ok(RoomMessageEventContent::text_plain( + "Event ID does not have a \"content\" key, this is not a message or an event \ + type that contains media.", + )); + } + } else { + return Ok(RoomMessageEventContent::text_plain( + "Event ID does not exist or is not known to us.", + )); + } + + if mxc_urls.is_empty() { + // we shouldn't get here (should have errored earlier) but just in case for + // whatever reason we do... + info!("Parsed event ID {event_id} but did not contain any MXC URLs."); + return Ok(RoomMessageEventContent::text_plain( + "Parsed event ID but found no MXC URLs.", + )); + } + + for mxc_url in mxc_urls { + services().media.delete(mxc_url).await?; + mxc_deletion_count += 1; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Deleted {} total MXCs from our database and the filesystem from event ID {event_id}.", + mxc_deletion_count + ))); + } else { + return Ok(RoomMessageEventContent::text_plain( + "Please specify either an MXC using --mxc or an event ID using --event-id of the \ + message containing an image. 
See --help for details.", + )); + } + }, + MediaCommand::DeleteList => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let mxc_list = body.clone().drain(1..body.len() - 1).collect::>(); + + let mut mxc_deletion_count = 0; + + for mxc in mxc_list { + debug!("Deleting MXC {} in bulk", mxc); + services().media.delete(mxc.to_owned()).await?; + mxc_deletion_count += 1; + } + + return Ok(RoomMessageEventContent::text_plain(format!( + "Finished bulk MXC deletion, deleted {} total MXCs from our database and the \ + filesystem.", + mxc_deletion_count + ))); + } else { + return Ok(RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + )); + } + }, + MediaCommand::DeletePastRemoteMedia { + duration, + } => { + let deleted_count = services().media.delete_all_remote_media_at_after_time(duration).await?; + + return Ok(RoomMessageEventContent::text_plain(format!( + "Deleted {} total files.", + deleted_count + ))); + }, + } + }, + AdminCommand::Users(command) => match command { + UserCommand::List => match services().users.list_local_users() { + Ok(users) => { + let mut msg: String = format!("Found {} local user account(s):\n", users.len()); + msg += &users.join("\n"); + RoomMessageEventContent::text_plain(&msg) + }, + Err(e) => RoomMessageEventContent::text_plain(e.to_string()), + }, + UserCommand::Create { + username, + password, + } => { + let password = password.unwrap_or_else(|| utils::random_string(AUTO_GEN_PASSWORD_LENGTH)); + // Validate user id + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {e}" + ))) + }, + }; + if user_id.is_historical() { + return Ok(RoomMessageEventContent::text_plain(format!( + "Userid {user_id} is not allowed due to historical" 
+ ))); + } + if services().users.exists(&user_id)? { + return Ok(RoomMessageEventContent::text_plain(format!("Userid {user_id} already exists"))); + } + // Create user + services().users.create(&user_id, Some(password.as_str()))?; + + // Default to pretty displayname + let mut displayname = user_id.localpart().to_owned(); + + // If `new_user_displayname_suffix` is set, registration will push whatever + // content is set to the user's display name with a space before it + if !services().globals.new_user_displayname_suffix().is_empty() { + displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix())); + } + + services().users.set_displayname(&user_id, Some(displayname)).await?; + + // Initial account data + services().account_data.update( + None, + &user_id, + ruma::events::GlobalAccountDataEventType::PushRules.to_string().into(), + &serde_json::to_value(ruma::events::push_rules::PushRulesEvent { + content: ruma::events::push_rules::PushRulesEventContent { + global: ruma::push::Ruleset::server_default(&user_id), + }, + }) + .expect("to json value always works"), + )?; + + // we dont add a device since we're not the user, just the creator + + // Inhibit login does not work for guests + RoomMessageEventContent::text_plain(format!( + "Created user with user_id: {user_id} and password: `{password}`" + )) + }, + UserCommand::Deactivate { + leave_rooms, + user_id, + } => { + let user_id = Arc::::from(user_id); + + // check if user belongs to our server + if user_id.server_name() != services().globals.server_name() { + return Ok(RoomMessageEventContent::text_plain(format!( + "User {user_id} does not belong to our server." + ))); + } + + if services().users.exists(&user_id)? { + RoomMessageEventContent::text_plain(format!( + "Making {user_id} leave all rooms before deactivation..." 
+ )); + + services().users.deactivate_account(&user_id)?; + + if leave_rooms { + leave_all_rooms(&user_id).await?; + } + + RoomMessageEventContent::text_plain(format!("User {user_id} has been deactivated")) + } else { + RoomMessageEventContent::text_plain(format!("User {user_id} doesn't exist on this server")) + } + }, + UserCommand::ResetPassword { + username, + } => { + let user_id = match UserId::parse_with_server_name( + username.as_str().to_lowercase(), + services().globals.server_name(), + ) { + Ok(id) => id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "The supplied username is not a valid username: {e}" + ))) + }, + }; + + // check if user belongs to our server + if user_id.server_name() != services().globals.server_name() { + return Ok(RoomMessageEventContent::text_plain(format!( + "User {user_id} does not belong to our server." + ))); + } + + // Check if the specified user is valid + if !services().users.exists(&user_id)? + || user_id + == UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("conduit user exists") + { + return Ok(RoomMessageEventContent::text_plain("The specified user does not exist!")); + } + + let new_password = utils::random_string(AUTO_GEN_PASSWORD_LENGTH); + + match services().users.set_password(&user_id, Some(new_password.as_str())) { + Ok(()) => RoomMessageEventContent::text_plain(format!( + "Successfully reset the password for user {user_id}: `{new_password}`" + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Couldn't reset the password for user {user_id}: {e}" + )), + } + }, + UserCommand::DeactivateAll { + leave_rooms, + force, + } => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let usernames = body.clone().drain(1..body.len() - 1).collect::>(); + + let mut user_ids: Vec<&UserId> = Vec::new(); + + for &username in &usernames { + match <&UserId>::try_from(username) { + Ok(user_id) => 
user_ids.push(user_id), + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "{username} is not a valid username: {e}" + ))) + }, + } + } + + let mut deactivation_count = 0; + let mut admins = Vec::new(); + + if !force { + user_ids.retain(|&user_id| match services().users.is_admin(user_id) { + Ok(is_admin) => match is_admin { + true => { + admins.push(user_id.localpart()); + false + }, + false => true, + }, + Err(_) => false, + }); + } + + for &user_id in &user_ids { + // check if user belongs to our server and skips over non-local users + if user_id.server_name() != services().globals.server_name() { + continue; + } + + if services().users.deactivate_account(user_id).is_ok() { + deactivation_count += 1; + } + } + + if leave_rooms { + for &user_id in &user_ids { + let _ = leave_all_rooms(user_id).await; + } + } + + if admins.is_empty() { + RoomMessageEventContent::text_plain(format!("Deactivated {deactivation_count} accounts.")) + } else { + RoomMessageEventContent::text_plain(format!( + "Deactivated {} accounts.\nSkipped admin accounts: {:?}. Use --force to deactivate \ + admin accounts", + deactivation_count, + admins.join(", ") + )) + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. 
Add --help for details.", + ) + } + }, + UserCommand::ListJoinedRooms { + user_id, + } => { + if user_id.server_name() != services().globals.server_name() { + return Ok(RoomMessageEventContent::text_plain("User does not belong to our server.")); + } + + let mut rooms: Vec<(OwnedRoomId, u64, String)> = vec![]; // room ID, members joined, room name + + for room_id in services().rooms.state_cache.rooms_joined(&user_id) { + let room_id = room_id?; + rooms.push(Self::get_room_info(room_id)); + } + + if rooms.is_empty() { + return Ok(RoomMessageEventContent::text_plain("User is not in any rooms.")); + } + + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let output_plain = format!( + "Rooms {user_id} Joined:\n{}", + rooms + .iter() + .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .collect::>() + .join("\n") + ); + let output_html = format!( + "\n\t\t\n{}
Rooms {user_id} \ + Joined
idmembersname
", + rooms.iter().fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name) + ) + .unwrap(); + output + }) + ); + RoomMessageEventContent::text_html(output_plain, output_html) + }, + }, + AdminCommand::Rooms(command) => match command { + RoomCommand::Moderation(command) => { + match command { + RoomModeration::BanRoom { + force, + room, + disable_federation, + } => { + debug!("Got room alias or ID: {}", room); + + let admin_room_alias: Box = + format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + let admin_room_id = services() + .rooms + .alias + .resolve_local_alias(&admin_room_alias)? + .expect("Admin room must exist"); + + if room.to_string().eq(&admin_room_id) || room.to_string().eq(&admin_room_alias) { + return Ok(RoomMessageEventContent::text_plain("Not allowed to ban the admin room.")); + } + + let room_id = if room.is_room_id() { + let room_id = match RoomId::parse(&room) { + Ok(room_id) => room_id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. Please note that this requires a full \ + room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!("Room specified is a room ID, banning room ID"); + + services().rooms.metadata.ban_room(&room_id, true)?; + + room_id + } else if room.is_room_alias_id() { + let room_alias = match RoomAliasId::parse(&room) { + Ok(room_alias) => room_alias, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. 
Please note that this requires a full \ + room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!( + "Room specified is not a room ID, attempting to resolve room alias to a room ID \ + locally, if not using get_alias_helper to fetch room ID remotely" + ); + + let room_id = match services().rooms.alias.resolve_local_alias(&room_alias)? { + Some(room_id) => room_id, + None => { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch \ + room ID over federation" + ); + + match get_alias_helper(room_alias).await { + Ok(response) => { + debug!( + "Got federation response fetching room ID for room {room}: {:?}", + response + ); + response.room_id + }, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }, + }; + + services().rooms.metadata.ban_room(&room_id, true)?; + + room_id + } else { + return Ok(RoomMessageEventContent::text_plain( + "Room specified is not a room ID or room alias. 
Please note that this requires a \ + full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`)", + )); + }; + + debug!("Making all users leave the room {}", &room); + if force { + for local_user in services() + .rooms + .state_cache + .room_members(&room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + local_user.server_name() == services().globals.server_name() // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) && (local_user.server_name() @@ -1163,24 +1184,25 @@ impl Service { .users .is_admin(local_user) .unwrap_or(true)) // since this is a force operation, assume user is an admin if somehow this fails - }) - }) - .collect::>() - { - debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, &room_id - ); - let _ = leave_room(&local_user, &room_id, None).await; - } - } else { - for local_user in services() - .rooms - .state_cache - .room_members(&room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + }) + }) + .collect::>() + { + debug!( + "Attempting leave for user {} in room {} (forced, ignoring all errors, \ + evicting admins too)", + &local_user, &room_id + ); + let _ = leave_room(&local_user, &room_id, None).await; + } + } else { + for local_user in services() + .rooms + .state_cache + .room_members(&room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + local_user.server_name() == services().globals.server_name() // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) && (local_user.server_name() @@ -1189,89 +1211,107 @@ impl Service { .users .is_admin(local_user) .unwrap_or(false)) - }) - }) - .collect::>() - { - debug!( - "Attempting leave for user {} in room {}", - 
&local_user, &room_id - ); - if let Err(e) = leave_room(&local_user, &room_id, None).await { - error!("Error attempting to make local user {} leave room {} during room banning: {}", &local_user, &room_id, e); - return Ok(RoomMessageEventContent::text_plain(format!("Error attempting to make local user {} leave room {} during room banning (room is still banned but not removing any more users): {}\nIf you would like to ignore errors, use --force", &local_user, &room_id, e))); - } - } - } + }) + }) + .collect::>() + { + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); + if let Err(e) = leave_room(&local_user, &room_id, None).await { + error!( + "Error attempting to make local user {} leave room {} during room \ + banning: {}", + &local_user, &room_id, e + ); + return Ok(RoomMessageEventContent::text_plain(format!( + "Error attempting to make local user {} leave room {} during room banning \ + (room is still banned but not removing any more users): {}\nIf you would \ + like to ignore errors, use --force", + &local_user, &room_id, e + ))); + } + } + } - if disable_federation { - services().rooms.metadata.disable_room(&room_id, true)?; - return Ok(RoomMessageEventContent::text_plain("Room banned, removed all our local users, and disabled incoming federation with room.")); - } + if disable_federation { + services().rooms.metadata.disable_room(&room_id, true)?; + return Ok(RoomMessageEventContent::text_plain( + "Room banned, removed all our local users, and disabled incoming federation with \ + room.", + )); + } - RoomMessageEventContent::text_plain("Room banned and removed all our local users, use disable-room to stop receiving new inbound federation events as well if needed.") - } - RoomModeration::BanListOfRooms { - force, - disable_federation, - } => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let rooms_s = body.clone().drain(1..body.len() - 1).collect::>(); + 
RoomMessageEventContent::text_plain( + "Room banned and removed all our local users, use disable-room to stop receiving new \ + inbound federation events as well if needed.", + ) + }, + RoomModeration::BanListOfRooms { + force, + disable_federation, + } => { + if body.len() > 2 + && body[0].trim().starts_with("```") + && body.last().unwrap().trim() == "```" + { + let rooms_s = body.clone().drain(1..body.len() - 1).collect::>(); - let mut room_ban_count = 0; - let mut room_ids: Vec<&RoomId> = Vec::new(); + let mut room_ban_count = 0; + let mut room_ids: Vec<&RoomId> = Vec::new(); - let admin_room_alias: Box = - format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let admin_room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias)? - .expect("Admin room must exist"); + let admin_room_alias: Box = + format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + let admin_room_id = services() + .rooms + .alias + .resolve_local_alias(&admin_room_alias)? 
+ .expect("Admin room must exist"); - for &room_id in &rooms_s { - match <&RoomId>::try_from(room_id) { - Ok(owned_room_id) => { - // silently ignore deleting admin room - if owned_room_id.eq(&admin_room_id) { - info!("User specified admin room in bulk ban list, ignoring"); - continue; - } + for &room_id in &rooms_s { + match <&RoomId>::try_from(room_id) { + Ok(owned_room_id) => { + // silently ignore deleting admin room + if owned_room_id.eq(&admin_room_id) { + info!("User specified admin room in bulk ban list, ignoring"); + continue; + } - room_ids.push(owned_room_id); - } - Err(e) => { - if force { - // ignore rooms we failed to parse if we're force deleting - error!("Error parsing room ID {room_id} during bulk room banning, ignoring error and logging here: {e}"); - continue; - } else { - return Ok(RoomMessageEventContent::text_plain(format!("{room_id} is not a valid room ID, please fix the list and try again: {e}"))); - } - } - } - } + room_ids.push(owned_room_id); + }, + Err(e) => { + if force { + // ignore rooms we failed to parse if we're force deleting + error!( + "Error parsing room ID {room_id} during bulk room banning, \ + ignoring error and logging here: {e}" + ); + continue; + } else { + return Ok(RoomMessageEventContent::text_plain(format!( + "{room_id} is not a valid room ID, please fix the list and try \ + again: {e}" + ))); + } + }, + } + } - for room_id in room_ids { - if services().rooms.metadata.ban_room(room_id, true).is_ok() { - debug!("Banned {room_id} successfully"); - room_ban_count += 1; - } + for room_id in room_ids { + if services().rooms.metadata.ban_room(room_id, true).is_ok() { + debug!("Banned {room_id} successfully"); + room_ban_count += 1; + } - debug!("Making all users leave the room {}", &room_id); - if force { - for local_user in services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + debug!("Making all 
users leave the room {}", &room_id); + if force { + for local_user in services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + local_user.server_name() == services().globals.server_name() // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) && (local_user.server_name() @@ -1280,24 +1320,25 @@ impl Service { .users .is_admin(local_user) .unwrap_or(true)) // since this is a force operation, assume user is an admin if somehow this fails - }) - }) - .collect::>() - { - debug!( - "Attempting leave for user {} in room {} (forced, ignoring all errors, evicting admins too)", - &local_user, room_id - ); - let _ = leave_room(&local_user, room_id, None).await; - } - } else { - for local_user in services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(|user| { - user.ok().filter(|local_user| { - local_user.server_name() == services().globals.server_name() + }) + }) + .collect::>() + { + debug!( + "Attempting leave for user {} in room {} (forced, ignoring all \ + errors, evicting admins too)", + &local_user, room_id + ); + let _ = leave_room(&local_user, room_id, None).await; + } + } else { + for local_user in services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(|user| { + user.ok().filter(|local_user| { + local_user.server_name() == services().globals.server_name() // additional wrapped check here is to avoid adding remote users // who are in the admin room to the list of local users (would fail auth check) && (local_user.server_name() @@ -1306,1116 +1347,1039 @@ impl Service { .users .is_admin(local_user) .unwrap_or(false)) - }) - }) - .collect::>() - { - debug!( - "Attempting leave for user {} in room {}", - &local_user, &room_id - ); - if let Err(e) = leave_room(&local_user, room_id, None).await { - error!("Error attempting to make local user {} leave room {} during bulk room banning: 
{}", &local_user, &room_id, e); - return Ok(RoomMessageEventContent::text_plain(format!("Error attempting to make local user {} leave room {} during room banning (room is still banned but not removing any more users and not banning any more rooms): {}\nIf you would like to ignore errors, use --force", &local_user, &room_id, e))); - } - } - } - - if disable_federation { - services().rooms.metadata.disable_room(room_id, true)?; - } - } - - if disable_federation { - return Ok(RoomMessageEventContent::text_plain(format!("Finished bulk room ban, banned {} total rooms, evicted all users, and disabled incoming federation with the room.", room_ban_count))); - } else { - return Ok(RoomMessageEventContent::text_plain(format!("Finished bulk room ban, banned {} total rooms and evicted all users.", room_ban_count))); - } - } else { - return Ok(RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - )); - } - } - RoomModeration::UnbanRoom { - room, - enable_federation, - } => { - let room_id = if room.is_room_id() { - let room_id = match RoomId::parse(&room) { - Ok(room_id) => room_id, - Err(e) => return Ok(RoomMessageEventContent::text_plain(format!("Failed to parse room ID {room}. Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"))), - }; - - debug!("Room specified is a room ID, unbanning room ID"); - - services().rooms.metadata.ban_room(&room_id, false)?; - - room_id - } else if room.is_room_alias_id() { - let room_alias = match RoomAliasId::parse(&room) { - Ok(room_alias) => room_alias, - Err(e) => return Ok(RoomMessageEventContent::text_plain(format!("Failed to parse room ID {room}. 
Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`): {e}"))), - }; - - debug!("Room specified is not a room ID, attempting to resolve room alias to a room ID locally, if not using get_alias_helper to fetch room ID remotely"); - - let room_id = match services() - .rooms - .alias - .resolve_local_alias(&room_alias)? - { - Some(room_id) => room_id, - None => { - debug!("We don't have this room alias to a room ID locally, attempting to fetch room ID over federation"); - - match get_alias_helper(room_alias).await { - Ok(response) => { - debug!("Got federation response fetching room ID for room {room}: {:?}", response); - response.room_id - } - Err(e) => { - return Ok(RoomMessageEventContent::text_plain(format!("Failed to resolve room alias {room} to a room ID: {e}"))); - } - } - } - }; - - services().rooms.metadata.ban_room(&room_id, false)?; - - room_id - } else { - return Ok(RoomMessageEventContent::text_plain("Room specified is not a room ID or room alias. 
Please note that this requires a full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias (`#roomalias:example.com`)")); - }; - - if enable_federation { - services().rooms.metadata.disable_room(&room_id, false)?; - return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); - } - - RoomMessageEventContent::text_plain("Room unbanned, you may need to re-enable federation with the room using enable-room if this is a remote room to make it fully functional.") - } - RoomModeration::ListBannedRooms => { - let rooms: Result, _> = - services().rooms.metadata.list_banned_rooms().collect(); - - match rooms { - Ok(room_ids) => { - // TODO: add room name from our state cache if available, default to the room ID as the room name if we dont have it - // TODO: do same if we have a room alias for this - let plain_list = - room_ids.iter().fold(String::new(), |mut output, room_id| { - writeln!(output, "- `{}`", room_id).unwrap(); - output - }); - - let html_list = - room_ids.iter().fold(String::new(), |mut output, room_id| { - writeln!( - output, - "
  • {}
  • ", - escape_html(room_id.as_ref()) - ) - .unwrap(); - output - }); - - let plain = format!("Rooms:\n{}", plain_list); - let html = format!("Rooms:\n
      {}
    ", html_list); - RoomMessageEventContent::text_html(plain, html) - } - Err(e) => { - error!("Failed to list banned rooms: {}", e); - RoomMessageEventContent::text_plain(format!( - "Unable to list room aliases: {}", - e - )) - } - } - } - }, - RoomCommand::List { page } => { - // TODO: i know there's a way to do this with clap, but i can't seem to find it - let page = page.unwrap_or(1); - let mut rooms = services() - .rooms - .metadata - .iter_ids() - .filter_map(std::result::Result::ok) - .map(Self::get_room_info) - .collect::>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let rooms: Vec<_> = rooms - .into_iter() - .skip(page.saturating_sub(1) * PAGE_SIZE) - .take(PAGE_SIZE) - .collect(); - - if rooms.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No more rooms.")); - }; - - let output_plain = format!( - "Rooms:\n{}", - rooms - .iter() - .map(|(id, members, name)| format!( - "{id}\tMembers: {members}\tName: {name}" - )) - .collect::>() - .join("\n") - ); - let output_html = format!( - "\n\t\t\n{}
    Room list - page {page}
    idmembersname
    ", - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!(output, "{}\t{}\t{}", escape_html(id.as_ref()), - members, - escape_html(name)).unwrap(); - output - }) - ); - RoomMessageEventContent::text_html(output_plain, output_html) - } - RoomCommand::Alias(command) => match command { - RoomAliasCommand::Set { - ref room_alias_localpart, - .. - } - | RoomAliasCommand::Remove { - ref room_alias_localpart, - } - | RoomAliasCommand::Which { - ref room_alias_localpart, - } => { - let room_alias_str = format!( - "#{}:{}", - room_alias_localpart, - services().globals.server_name() - ); - let room_alias = match RoomAliasId::parse_box(room_alias_str) { - Ok(alias) => alias, - Err(err) => { - return Ok(RoomMessageEventContent::text_plain(format!( - "Failed to parse alias: {}", - err - ))) - } - }; - - match command { - RoomAliasCommand::Set { force, room_id, .. } => { - match (force, services().rooms.alias.resolve_local_alias(&room_alias)) { - (true, Ok(Some(id))) => match services().rooms.alias.set_alias(&room_alias, &room_id) { - Ok(()) => RoomMessageEventContent::text_plain(format!("Successfully overwrote alias (formerly {})", id)), - Err(err) => RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err)), - } - (false, Ok(Some(id))) => { - RoomMessageEventContent::text_plain(format!("Refusing to overwrite in use alias for {}, use -f or --force to overwrite", id)) - } - (_, Ok(None)) => match services().rooms.alias.set_alias(&room_alias, &room_id) { - Ok(()) => RoomMessageEventContent::text_plain("Successfully set alias"), - Err(err) => RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err)), - } - (_, Err(err)) => RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err)), - } - } - RoomAliasCommand::Remove { .. 
} => { - match services().rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => { - match services().rooms.alias.remove_alias(&room_alias) { - Ok(()) => RoomMessageEventContent::text_plain(format!( - "Removed alias from {}", - id - )), - Err(err) => RoomMessageEventContent::text_plain( - format!("Failed to remove alias: {}", err), - ), - } - } - Ok(None) => { - RoomMessageEventContent::text_plain("Alias isn't in use.") - } - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to lookup alias: {}", - err - )), - } - } - RoomAliasCommand::Which { .. } => { - match services().rooms.alias.resolve_local_alias(&room_alias) { - Ok(Some(id)) => RoomMessageEventContent::text_plain(format!( - "Alias resolves to {}", - id - )), - Ok(None) => { - RoomMessageEventContent::text_plain("Alias isn't in use.") - } - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to lookup alias: {}", - err - )), - } - } - RoomAliasCommand::List { .. } => unreachable!(), - } - } - RoomAliasCommand::List { room_id } => match room_id { - Some(room_id) => { - let aliases: Result, _> = services() - .rooms - .alias - .local_aliases_for_room(&room_id) - .collect(); - match aliases { - Ok(aliases) => { - let plain_list: String = - aliases.iter().fold(String::new(), |mut output, alias| { - writeln!(output, "- {}", alias).unwrap(); - output - }); - - let html_list: String = - aliases.iter().fold(String::new(), |mut output, alias| { - writeln!( - output, - "
  • {}
  • ", - escape_html(alias.as_ref()) - ) - .unwrap(); - output - }); - - let plain = format!("Aliases for {}:\n{}", room_id, plain_list); - let html = - format!("Aliases for {}:\n
      {}
    ", room_id, html_list); - RoomMessageEventContent::text_html(plain, html) - } - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to list aliases: {}", - err - )), - } - } - None => { - let aliases: Result, _> = - services().rooms.alias.all_local_aliases().collect(); - match aliases { - Ok(aliases) => { - let server_name = services().globals.server_name(); - let plain_list: String = aliases.iter().fold( - String::new(), - |mut output, (alias, id)| { - writeln!( - output, - "- `{}` -> #{}:{}", - alias, id, server_name - ) - .unwrap(); - output - }, - ); - - let html_list: String = aliases.iter().fold( - String::new(), - |mut output, (alias, id)| { - writeln!( - output, - "
  • {} -> #{}:{}
  • ", - escape_html(alias.as_ref()), - escape_html(id.as_ref()), - server_name - ) - .unwrap(); - output - }, - ); - - let plain = format!("Aliases:\n{}", plain_list); - let html = format!("Aliases:\n
      {}
    ", html_list); - RoomMessageEventContent::text_html(plain, html) - } - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to list room aliases: {}", - err - )), - } - } - }, - }, - RoomCommand::Directory(command) => match command { - RoomDirectoryCommand::Publish { room_id } => { - match services().rooms.directory.set_public(&room_id) { - Ok(()) => RoomMessageEventContent::text_plain("Room published"), - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to update room: {}", - err - )), - } - } - RoomDirectoryCommand::Unpublish { room_id } => { - match services().rooms.directory.set_not_public(&room_id) { - Ok(()) => RoomMessageEventContent::text_plain("Room unpublished"), - Err(err) => RoomMessageEventContent::text_plain(format!( - "Unable to update room: {}", - err - )), - } - } - RoomDirectoryCommand::List { page } => { - // TODO: i know there's a way to do this with clap, but i can't seem to find it - let page = page.unwrap_or(1); - let mut rooms = services() - .rooms - .directory - .public_rooms() - .filter_map(std::result::Result::ok) - .map(Self::get_room_info) - .collect::>(); - rooms.sort_by_key(|r| r.1); - rooms.reverse(); - - let rooms: Vec<_> = rooms - .into_iter() - .skip(page.saturating_sub(1) * PAGE_SIZE) - .take(PAGE_SIZE) - .collect(); - - if rooms.is_empty() { - return Ok(RoomMessageEventContent::text_plain("No more rooms.")); - }; - - let output_plain = format!( - "Rooms:\n{}", - rooms - .iter() - .map(|(id, members, name)| format!( - "{id}\tMembers: {members}\tName: {name}" - )) - .collect::>() - .join("\n") - ); - let output_html = format!( - "\n\t\t\n{}
    Room directory - page {page}
    idmembersname
    ", - rooms - .iter() - .fold(String::new(), |mut output, (id, members, name)| { - writeln!(output, "{}\t{}\t{}", escape_html(id.as_ref()), members, escape_html(name.as_ref())).unwrap(); - output - }) - ); - RoomMessageEventContent::text_html(output_plain, output_html) - } - }, - }, - AdminCommand::Federation(command) => match command { - FederationCommand::DisableRoom { room_id } => { - services().rooms.metadata.disable_room(&room_id, true)?; - RoomMessageEventContent::text_plain("Room disabled.") - } - FederationCommand::EnableRoom { room_id } => { - services().rooms.metadata.disable_room(&room_id, false)?; - RoomMessageEventContent::text_plain("Room enabled.") - } - FederationCommand::IncomingFederation => { - let map = services() - .globals - .roomid_federationhandletime - .read() - .unwrap(); - let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); - - for (r, (e, i)) in map.iter() { - let elapsed = i.elapsed(); - let _ = writeln!( - msg, - "{} {}: {}m{}s", - r, - e, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - RoomMessageEventContent::text_plain(&msg) - } - FederationCommand::SignJson => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let string = body[1..body.len() - 1].join("\n"); - match serde_json::from_str(&string) { - Ok(mut value) => { - ruma::signatures::sign_json( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut value, - ) - .expect("our request json is what ruma expects"); - let json_text = serde_json::to_string_pretty(&value) - .expect("canonical json is valid json"); - RoomMessageEventContent::text_plain(json_text) - } - Err(e) => { - RoomMessageEventContent::text_plain(format!("Invalid json: {e}")) - } - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. 
Add --help for details.", - ) - } - } - FederationCommand::VerifyJson => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let string = body[1..body.len() - 1].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - let pub_key_map = RwLock::new(BTreeMap::new()); - - services() - .rooms - .event_handler - .fetch_required_signing_keys([&value], &pub_key_map) - .await?; - - let pub_key_map = pub_key_map.read().unwrap(); - match ruma::signatures::verify_json(&pub_key_map, &value) { - Ok(_) => { - RoomMessageEventContent::text_plain("Signature correct") - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Signature verification failed: {e}" - )), - } - } - Err(e) => { - RoomMessageEventContent::text_plain(format!("Invalid json: {e}")) - } - } - } else { - RoomMessageEventContent::text_plain( - "Expected code block in command body. Add --help for details.", - ) - } - } - }, - AdminCommand::Server(command) => match command { - ServerCommand::ShowConfig => { - // Construct and send the response - RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) - } - ServerCommand::MemoryUsage => { - let response1 = services().memory_usage(); - let response2 = services().globals.db.memory_usage(); - - RoomMessageEventContent::text_plain(format!( - "Services:\n{response1}\n\nDatabase:\n{response2}" - )) - } - ServerCommand::ClearDatabaseCaches { amount } => { - services().globals.db.clear_caches(amount); - - RoomMessageEventContent::text_plain("Done.") - } - ServerCommand::ClearServiceCaches { amount } => { - services().clear_caches(amount); - - RoomMessageEventContent::text_plain("Done.") - } - }, - AdminCommand::Debug(command) => match command { - DebugCommand::GetAuthChain { event_id } => { - let event_id = Arc::::from(event_id); - if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? 
{ - let room_id_str = event - .get("room_id") - .and_then(|val| val.as_str()) - .ok_or_else(|| Error::bad_database("Invalid event in database"))?; - - let room_id = <&RoomId>::try_from(room_id_str).map_err(|_| { - Error::bad_database("Invalid room id field in event in database") - })?; - let start = Instant::now(); - let count = services() - .rooms - .auth_chain - .get_auth_chain(room_id, vec![event_id]) - .await? - .count(); - let elapsed = start.elapsed(); - RoomMessageEventContent::text_plain(format!( - "Loaded auth chain with length {count} in {elapsed:?}" - )) - } else { - RoomMessageEventContent::text_plain("Event not found.") - } - } - DebugCommand::ParsePdu => { - if body.len() > 2 - && body[0].trim().starts_with("```") - && body.last().unwrap().trim() == "```" - { - let string = body[1..body.len() - 1].join("\n"); - match serde_json::from_str(&string) { - Ok(value) => { - match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { - Ok(hash) => { - let event_id = EventId::parse(format!("${hash}")); - - match serde_json::from_value::( - serde_json::to_value(value).expect("value is json"), - ) { - Ok(pdu) => RoomMessageEventContent::text_plain( - format!("EventId: {event_id:?}\n{pdu:#?}"), - ), - Err(e) => RoomMessageEventContent::text_plain(format!( - "EventId: {event_id:?}\nCould not parse event: {e}" - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Could not parse PDU JSON: {e:?}" - )), - } - } - Err(e) => RoomMessageEventContent::text_plain(format!( - "Invalid json in command body: {e}" - )), - } - } else { - RoomMessageEventContent::text_plain("Expected code block in command body.") - } - } - DebugCommand::GetPdu { event_id } => { - let mut outlier = false; - let mut pdu_json = services() - .rooms - .timeline - .get_non_outlier_pdu_json(&event_id)?; - if pdu_json.is_none() { - outlier = true; - pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; - } - match pdu_json { - Some(json) => { - let json_text = 
serde_json::to_string_pretty(&json) - .expect("canonical json is valid json"); - RoomMessageEventContent::text_html( - format!( - "{}\n```json\n{}\n```", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - json_text - ), - format!( - "

    {}

    \n
    {}\n
    \n", - if outlier { - "PDU is outlier" - } else { - "PDU was accepted" - }, - HtmlEscape(&json_text) - ), - ) - } - None => RoomMessageEventContent::text_plain("PDU not found."), - } - } - DebugCommand::ForceDeviceListUpdates => { - // Force E2EE device list updates for all users - for user_id in services().users.iter().filter_map(std::result::Result::ok) { - services().users.mark_device_key_update(&user_id)?; - } - RoomMessageEventContent::text_plain( - "Marked all devices for all users as having new keys to update", - ) - } - }, - }; - - Ok(reply_message_content) - } - - fn get_room_info(id: OwnedRoomId) -> (OwnedRoomId, u64, String) { - ( - id.clone(), - services() - .rooms - .state_cache - .room_joined_count(&id) - .ok() - .flatten() - .unwrap_or(0), - services() - .rooms - .state_accessor - .get_name(&id) - .ok() - .flatten() - .unwrap_or_else(|| id.to_string()), - ) - } - - // Utility to turn clap's `--help` text to HTML. - fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String { - // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: subcmdname` - let text = text.replace( - &format!("@conduit:{server_name}:-"), - &format!("@conduit:{server_name}: "), - ); - - // For the conduit admin room, subcommands become main commands - let text = text.replace("SUBCOMMAND", "COMMAND"); - let text = text.replace("subcommand", "command"); - - // Escape option names (e.g. ``) since they look like HTML tags - let text = escape_html(&text); - - // Italicize the first line (command name and version text) - let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1\n"); - - // Unmerge wrapped lines - let text = text.replace("\n ", " "); - - // Wrap option names in backticks. 
The lines look like: - // -V, --version Prints version information - // And are converted to: - // -V, --version: Prints version information - // (?m) enables multi-line mode for ^ and $ - let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$") - .expect("Regex compilation should not fail"); - let text = re.replace_all(&text, "$1: $4"); - - // Look for a `[commandbody]` tag. If it exists, use all lines below it that - // start with a `#` in the USAGE section. - let mut text_lines: Vec<&str> = text.lines().collect(); - let mut command_body = String::new(); - - if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { - text_lines.remove(line_index); - - while text_lines - .get(line_index) - .map(|line| line.starts_with('#')) - .unwrap_or(false) - { - command_body += if text_lines[line_index].starts_with("# ") { - &text_lines[line_index][2..] - } else { - &text_lines[line_index][1..] - }; - command_body += "[nobr]\n"; - text_lines.remove(line_index); - } - } - - let text = text_lines.join("\n"); - - // Improve the usage section - let text = if command_body.is_empty() { - // Wrap the usage line in code tags - let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$") - .expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n$1").to_string() - } else { - // Wrap the usage line in a code block, and add a yaml block example - // This makes the usage of e.g. `register-appservice` more accurate - let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n") - .expect("Regex compilation should not fail"); - re.replace_all(&text, "USAGE:\n
    $1[nobr]\n[commandbodyblock]
    ") - .replace("[commandbodyblock]", &command_body) - }; - - // Add HTML line-breaks - - text.replace("\n\n\n", "\n\n") - .replace('\n', "
    \n") - .replace("[nobr]
    ", "") - } - - /// Create the admin room. - /// - /// Users in this room are considered admins by conduit, and the room can be - /// used to issue admin commands by talking to the server user inside it. - pub(crate) async fn create_admin_room(&self) -> Result<()> { - let room_id = RoomId::new(services().globals.server_name()); - - services().rooms.short.get_or_create_shortroomid(&room_id)?; - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Create a user for the server - let conduit_user = - UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); - - services().users.create(&conduit_user, None)?; - - let room_version = services().globals.default_room_version(); - let mut content = match room_version { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => RoomCreateEventContent::new_v1(conduit_user.clone()), - RoomVersionId::V11 => RoomCreateEventContent::new_v11(), - _ => { - warn!("Unexpected or unsupported room version {}", room_version); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - } - }; - - content.federate = true; - content.predecessor = None; - content.room_version = room_version; - - // 1. The room create event - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCreate, - content: to_raw_value(&content).expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 2. 
Make conduit bot join - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(conduit_user.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 3. Power levels - let mut users = BTreeMap::new(); - users.insert(conduit_user.clone(), 100.into()); - - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 4.1 Join Rules - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomJoinRules, - content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 4.2 History Visibility - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomHistoryVisibility, - content: to_raw_value(&RoomHistoryVisibilityEventContent::new( - HistoryVisibility::Shared, - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 4.3 Guest Access - services() - .rooms - .timeline - 
.build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomGuestAccess, - content: to_raw_value(&RoomGuestAccessEventContent::new( - GuestAccess::Forbidden, - )) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 5. Events implied by name and topic - let room_name = format!("{} Admin Room", services().globals.server_name()); - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomName, - content: to_raw_value(&RoomNameEventContent::new(room_name)) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomTopic, - content: to_raw_value(&RoomTopicEventContent { - topic: format!("Manage {}", services().globals.server_name()), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // 6. Room alias - let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomCanonicalAlias, - content: to_raw_value(&RoomCanonicalAliasEventContent { - alias: Some(alias.clone()), - alt_aliases: Vec::new(), - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - services().rooms.alias.set_alias(&alias, &room_id)?; - - Ok(()) - } - - /// Invite the user to the conduit admin room. 
- /// - /// In conduit, this is equivalent to granting admin privileges. - pub(crate) async fn make_user_admin( - &self, - user_id: &UserId, - displayname: String, - ) -> Result<()> { - let admin_room_alias: Box = - format!("#admins:{}", services().globals.server_name()) - .try_into() - .expect("#admins:server_name is a valid alias name"); - let room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias)? - .expect("Admin room must exist"); - - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Use the server user to grant the new admin's power level - let conduit_user = - UserId::parse_with_server_name("conduit", services().globals.server_name()) - .expect("@conduit:server_name is valid"); - - // Invite and join the real user - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Invite, - displayname: None, - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomMember, - content: to_raw_value(&RoomMemberEventContent { - membership: MembershipState::Join, - displayname: Some(displayname), - avatar_url: None, - is_direct: None, - third_party_invite: None, - blurhash: None, - reason: None, - join_authorized_via_users_server: None, - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some(user_id.to_string()), - redacts: None, - }, - user_id, 
- &room_id, - &state_lock, - ) - .await?; - - // Set power level - let mut users = BTreeMap::new(); - users.insert(conduit_user.clone(), 100.into()); - users.insert(user_id.to_owned(), 100.into()); - - services() - .rooms - .timeline - .build_and_append_pdu( - PduBuilder { - event_type: TimelineEventType::RoomPowerLevels, - content: to_raw_value(&RoomPowerLevelsEventContent { - users, - ..Default::default() - }) - .expect("event is valid, we just created it"), - unsigned: None, - state_key: Some("".to_owned()), - redacts: None, - }, - &conduit_user, - &room_id, - &state_lock, - ) - .await?; - - // Send welcome message - services().rooms.timeline.build_and_append_pdu( + }) + }) + .collect::>() + { + debug!("Attempting leave for user {} in room {}", &local_user, &room_id); + if let Err(e) = leave_room(&local_user, room_id, None).await { + error!( + "Error attempting to make local user {} leave room {} during bulk \ + room banning: {}", + &local_user, &room_id, e + ); + return Ok(RoomMessageEventContent::text_plain(format!( + "Error attempting to make local user {} leave room {} during room \ + banning (room is still banned but not removing any more users \ + and not banning any more rooms): {}\nIf you would like to ignore \ + errors, use --force", + &local_user, &room_id, e + ))); + } + } + } + + if disable_federation { + services().rooms.metadata.disable_room(room_id, true)?; + } + } + + if disable_federation { + return Ok(RoomMessageEventContent::text_plain(format!( + "Finished bulk room ban, banned {} total rooms, evicted all users, and \ + disabled incoming federation with the room.", + room_ban_count + ))); + } else { + return Ok(RoomMessageEventContent::text_plain(format!( + "Finished bulk room ban, banned {} total rooms and evicted all users.", + room_ban_count + ))); + } + } else { + return Ok(RoomMessageEventContent::text_plain( + "Expected code block in command body. 
Add --help for details.", + )); + } + }, + RoomModeration::UnbanRoom { + room, + enable_federation, + } => { + let room_id = if room.is_room_id() { + let room_id = match RoomId::parse(&room) { + Ok(room_id) => room_id, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. Please note that this requires a full \ + room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!("Room specified is a room ID, unbanning room ID"); + + services().rooms.metadata.ban_room(&room_id, false)?; + + room_id + } else if room.is_room_alias_id() { + let room_alias = match RoomAliasId::parse(&room) { + Ok(room_alias) => room_alias, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse room ID {room}. Please note that this requires a full \ + room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`): {e}" + ))) + }, + }; + + debug!( + "Room specified is not a room ID, attempting to resolve room alias to a room ID \ + locally, if not using get_alias_helper to fetch room ID remotely" + ); + + let room_id = match services().rooms.alias.resolve_local_alias(&room_alias)? { + Some(room_id) => room_id, + None => { + debug!( + "We don't have this room alias to a room ID locally, attempting to fetch \ + room ID over federation" + ); + + match get_alias_helper(room_alias).await { + Ok(response) => { + debug!( + "Got federation response fetching room ID for room {room}: {:?}", + response + ); + response.room_id + }, + Err(e) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to resolve room alias {room} to a room ID: {e}" + ))); + }, + } + }, + }; + + services().rooms.metadata.ban_room(&room_id, false)?; + + room_id + } else { + return Ok(RoomMessageEventContent::text_plain( + "Room specified is not a room ID or room alias. 
Please note that this requires a \ + full room ID (`!awIh6gGInaS5wLQJwa:example.com`) or a room alias \ + (`#roomalias:example.com`)", + )); + }; + + if enable_federation { + services().rooms.metadata.disable_room(&room_id, false)?; + return Ok(RoomMessageEventContent::text_plain("Room unbanned.")); + } + + RoomMessageEventContent::text_plain( + "Room unbanned, you may need to re-enable federation with the room using enable-room \ + if this is a remote room to make it fully functional.", + ) + }, + RoomModeration::ListBannedRooms => { + let rooms: Result, _> = services().rooms.metadata.list_banned_rooms().collect(); + + match rooms { + Ok(room_ids) => { + // TODO: add room name from our state cache if available, default to the room ID + // as the room name if we dont have it TODO: do same if we have a room alias for + // this + let plain_list = room_ids.iter().fold(String::new(), |mut output, room_id| { + writeln!(output, "- `{}`", room_id).unwrap(); + output + }); + + let html_list = room_ids.iter().fold(String::new(), |mut output, room_id| { + writeln!(output, "
  • {}
  • ", escape_html(room_id.as_ref())) + .unwrap(); + output + }); + + let plain = format!("Rooms:\n{}", plain_list); + let html = format!("Rooms:\n
      {}
    ", html_list); + RoomMessageEventContent::text_html(plain, html) + }, + Err(e) => { + error!("Failed to list banned rooms: {}", e); + RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {}", e)) + }, + } + }, + } + }, + RoomCommand::List { + page, + } => { + // TODO: i know there's a way to do this with clap, but i can't seem to find it + let page = page.unwrap_or(1); + let mut rooms = services() + .rooms + .metadata + .iter_ids() + .filter_map(std::result::Result::ok) + .map(Self::get_room_info) + .collect::>(); + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let rooms: Vec<_> = + rooms.into_iter().skip(page.saturating_sub(1) * PAGE_SIZE).take(PAGE_SIZE).collect(); + + if rooms.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No more rooms.")); + }; + + let output_plain = format!( + "Rooms:\n{}", + rooms + .iter() + .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .collect::>() + .join("\n") + ); + let output_html = format!( + "\n\t\t\n{}
    Room list - page \ + {page}
    idmembersname
    ", + rooms.iter().fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name) + ) + .unwrap(); + output + }) + ); + RoomMessageEventContent::text_html(output_plain, output_html) + }, + RoomCommand::Alias(command) => match command { + RoomAliasCommand::Set { + ref room_alias_localpart, + .. + } + | RoomAliasCommand::Remove { + ref room_alias_localpart, + } + | RoomAliasCommand::Which { + ref room_alias_localpart, + } => { + let room_alias_str = format!("#{}:{}", room_alias_localpart, services().globals.server_name()); + let room_alias = match RoomAliasId::parse_box(room_alias_str) { + Ok(alias) => alias, + Err(err) => { + return Ok(RoomMessageEventContent::text_plain(format!( + "Failed to parse alias: {}", + err + ))) + }, + }; + + match command { + RoomAliasCommand::Set { + force, + room_id, + .. + } => match (force, services().rooms.alias.resolve_local_alias(&room_alias)) { + (true, Ok(Some(id))) => match services().rooms.alias.set_alias(&room_alias, &room_id) { + Ok(()) => RoomMessageEventContent::text_plain(format!( + "Successfully overwrote alias (formerly {})", + id + )), + Err(err) => { + RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err)) + }, + }, + (false, Ok(Some(id))) => RoomMessageEventContent::text_plain(format!( + "Refusing to overwrite in use alias for {}, use -f or --force to overwrite", + id + )), + (_, Ok(None)) => match services().rooms.alias.set_alias(&room_alias, &room_id) { + Ok(()) => RoomMessageEventContent::text_plain("Successfully set alias"), + Err(err) => { + RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err)) + }, + }, + (_, Err(err)) => { + RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err)) + }, + }, + RoomAliasCommand::Remove { + .. 
+ } => match services().rooms.alias.resolve_local_alias(&room_alias) { + Ok(Some(id)) => match services().rooms.alias.remove_alias(&room_alias) { + Ok(()) => RoomMessageEventContent::text_plain(format!("Removed alias from {}", id)), + Err(err) => { + RoomMessageEventContent::text_plain(format!("Failed to remove alias: {}", err)) + }, + }, + Ok(None) => RoomMessageEventContent::text_plain("Alias isn't in use."), + Err(err) => { + RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err)) + }, + }, + RoomAliasCommand::Which { + .. + } => match services().rooms.alias.resolve_local_alias(&room_alias) { + Ok(Some(id)) => { + RoomMessageEventContent::text_plain(format!("Alias resolves to {}", id)) + }, + Ok(None) => RoomMessageEventContent::text_plain("Alias isn't in use."), + Err(err) => { + RoomMessageEventContent::text_plain(format!("Unable to lookup alias: {}", err)) + }, + }, + RoomAliasCommand::List { + .. + } => unreachable!(), + } + }, + RoomAliasCommand::List { + room_id, + } => match room_id { + Some(room_id) => { + let aliases: Result, _> = + services().rooms.alias.local_aliases_for_room(&room_id).collect(); + match aliases { + Ok(aliases) => { + let plain_list: String = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "- {}", alias).unwrap(); + output + }); + + let html_list: String = aliases.iter().fold(String::new(), |mut output, alias| { + writeln!(output, "
  • {}
  • ", escape_html(alias.as_ref())).unwrap(); + output + }); + + let plain = format!("Aliases for {}:\n{}", room_id, plain_list); + let html = format!("Aliases for {}:\n
      {}
    ", room_id, html_list); + RoomMessageEventContent::text_html(plain, html) + }, + Err(err) => { + RoomMessageEventContent::text_plain(format!("Unable to list aliases: {}", err)) + }, + } + }, + None => { + let aliases: Result, _> = services().rooms.alias.all_local_aliases().collect(); + match aliases { + Ok(aliases) => { + let server_name = services().globals.server_name(); + let plain_list: String = + aliases.iter().fold(String::new(), |mut output, (alias, id)| { + writeln!(output, "- `{}` -> #{}:{}", alias, id, server_name).unwrap(); + output + }); + + let html_list: String = + aliases.iter().fold(String::new(), |mut output, (alias, id)| { + writeln!( + output, + "
  • {} -> #{}:{}
  • ", + escape_html(alias.as_ref()), + escape_html(id.as_ref()), + server_name + ) + .unwrap(); + output + }); + + let plain = format!("Aliases:\n{}", plain_list); + let html = format!("Aliases:\n
      {}
    ", html_list); + RoomMessageEventContent::text_html(plain, html) + }, + Err(err) => { + RoomMessageEventContent::text_plain(format!("Unable to list room aliases: {}", err)) + }, + } + }, + }, + }, + RoomCommand::Directory(command) => match command { + RoomDirectoryCommand::Publish { + room_id, + } => match services().rooms.directory.set_public(&room_id) { + Ok(()) => RoomMessageEventContent::text_plain("Room published"), + Err(err) => RoomMessageEventContent::text_plain(format!("Unable to update room: {}", err)), + }, + RoomDirectoryCommand::Unpublish { + room_id, + } => match services().rooms.directory.set_not_public(&room_id) { + Ok(()) => RoomMessageEventContent::text_plain("Room unpublished"), + Err(err) => RoomMessageEventContent::text_plain(format!("Unable to update room: {}", err)), + }, + RoomDirectoryCommand::List { + page, + } => { + // TODO: i know there's a way to do this with clap, but i can't seem to find it + let page = page.unwrap_or(1); + let mut rooms = services() + .rooms + .directory + .public_rooms() + .filter_map(std::result::Result::ok) + .map(Self::get_room_info) + .collect::>(); + rooms.sort_by_key(|r| r.1); + rooms.reverse(); + + let rooms: Vec<_> = + rooms.into_iter().skip(page.saturating_sub(1) * PAGE_SIZE).take(PAGE_SIZE).collect(); + + if rooms.is_empty() { + return Ok(RoomMessageEventContent::text_plain("No more rooms.")); + }; + + let output_plain = format!( + "Rooms:\n{}", + rooms + .iter() + .map(|(id, members, name)| format!("{id}\tMembers: {members}\tName: {name}")) + .collect::>() + .join("\n") + ); + let output_html = format!( + "\n\t\t\n{}
    Room directory - page \ + {page}
    idmembersname
    ", + rooms.iter().fold(String::new(), |mut output, (id, members, name)| { + writeln!( + output, + "{}\t{}\t{}", + escape_html(id.as_ref()), + members, + escape_html(name.as_ref()) + ) + .unwrap(); + output + }) + ); + RoomMessageEventContent::text_html(output_plain, output_html) + }, + }, + }, + AdminCommand::Federation(command) => match command { + FederationCommand::DisableRoom { + room_id, + } => { + services().rooms.metadata.disable_room(&room_id, true)?; + RoomMessageEventContent::text_plain("Room disabled.") + }, + FederationCommand::EnableRoom { + room_id, + } => { + services().rooms.metadata.disable_room(&room_id, false)?; + RoomMessageEventContent::text_plain("Room enabled.") + }, + FederationCommand::IncomingFederation => { + let map = services().globals.roomid_federationhandletime.read().unwrap(); + let mut msg: String = format!("Handling {} incoming pdus:\n", map.len()); + + for (r, (e, i)) in map.iter() { + let elapsed = i.elapsed(); + let _ = writeln!(msg, "{} {}: {}m{}s", r, e, elapsed.as_secs() / 60, elapsed.as_secs() % 60); + } + RoomMessageEventContent::text_plain(&msg) + }, + FederationCommand::SignJson => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let string = body[1..body.len() - 1].join("\n"); + match serde_json::from_str(&string) { + Ok(mut value) => { + ruma::signatures::sign_json( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut value, + ) + .expect("our request json is what ruma expects"); + let json_text = + serde_json::to_string_pretty(&value).expect("canonical json is valid json"); + RoomMessageEventContent::text_plain(json_text) + }, + Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. 
Add --help for details.", + ) + } + }, + FederationCommand::VerifyJson => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let string = body[1..body.len() - 1].join("\n"); + match serde_json::from_str(&string) { + Ok(value) => { + let pub_key_map = RwLock::new(BTreeMap::new()); + + services() + .rooms + .event_handler + .fetch_required_signing_keys([&value], &pub_key_map) + .await?; + + let pub_key_map = pub_key_map.read().unwrap(); + match ruma::signatures::verify_json(&pub_key_map, &value) { + Ok(_) => RoomMessageEventContent::text_plain("Signature correct"), + Err(e) => RoomMessageEventContent::text_plain(format!( + "Signature verification failed: {e}" + )), + } + }, + Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json: {e}")), + } + } else { + RoomMessageEventContent::text_plain( + "Expected code block in command body. Add --help for details.", + ) + } + }, + }, + AdminCommand::Server(command) => match command { + ServerCommand::ShowConfig => { + // Construct and send the response + RoomMessageEventContent::text_plain(format!("{}", services().globals.config)) + }, + ServerCommand::MemoryUsage => { + let response1 = services().memory_usage(); + let response2 = services().globals.db.memory_usage(); + + RoomMessageEventContent::text_plain(format!("Services:\n{response1}\n\nDatabase:\n{response2}")) + }, + ServerCommand::ClearDatabaseCaches { + amount, + } => { + services().globals.db.clear_caches(amount); + + RoomMessageEventContent::text_plain("Done.") + }, + ServerCommand::ClearServiceCaches { + amount, + } => { + services().clear_caches(amount); + + RoomMessageEventContent::text_plain("Done.") + }, + }, + AdminCommand::Debug(command) => match command { + DebugCommand::GetAuthChain { + event_id, + } => { + let event_id = Arc::::from(event_id); + if let Some(event) = services().rooms.timeline.get_pdu_json(&event_id)? 
{ + let room_id_str = event + .get("room_id") + .and_then(|val| val.as_str()) + .ok_or_else(|| Error::bad_database("Invalid event in database"))?; + + let room_id = <&RoomId>::try_from(room_id_str) + .map_err(|_| Error::bad_database("Invalid room id field in event in database"))?; + let start = Instant::now(); + let count = services().rooms.auth_chain.get_auth_chain(room_id, vec![event_id]).await?.count(); + let elapsed = start.elapsed(); + RoomMessageEventContent::text_plain(format!( + "Loaded auth chain with length {count} in {elapsed:?}" + )) + } else { + RoomMessageEventContent::text_plain("Event not found.") + } + }, + DebugCommand::ParsePdu => { + if body.len() > 2 && body[0].trim().starts_with("```") && body.last().unwrap().trim() == "```" { + let string = body[1..body.len() - 1].join("\n"); + match serde_json::from_str(&string) { + Ok(value) => match ruma::signatures::reference_hash(&value, &RoomVersionId::V6) { + Ok(hash) => { + let event_id = EventId::parse(format!("${hash}")); + + match serde_json::from_value::( + serde_json::to_value(value).expect("value is json"), + ) { + Ok(pdu) => RoomMessageEventContent::text_plain(format!( + "EventId: {event_id:?}\n{pdu:#?}" + )), + Err(e) => RoomMessageEventContent::text_plain(format!( + "EventId: {event_id:?}\nCould not parse event: {e}" + )), + } + }, + Err(e) => { + RoomMessageEventContent::text_plain(format!("Could not parse PDU JSON: {e:?}")) + }, + }, + Err(e) => RoomMessageEventContent::text_plain(format!("Invalid json in command body: {e}")), + } + } else { + RoomMessageEventContent::text_plain("Expected code block in command body.") + } + }, + DebugCommand::GetPdu { + event_id, + } => { + let mut outlier = false; + let mut pdu_json = services().rooms.timeline.get_non_outlier_pdu_json(&event_id)?; + if pdu_json.is_none() { + outlier = true; + pdu_json = services().rooms.timeline.get_pdu_json(&event_id)?; + } + match pdu_json { + Some(json) => { + let json_text = 
serde_json::to_string_pretty(&json).expect("canonical json is valid json"); + RoomMessageEventContent::text_html( + format!( + "{}\n```json\n{}\n```", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + json_text + ), + format!( + "

    {}

    \n
    {}\n
    \n", + if outlier { + "PDU is outlier" + } else { + "PDU was accepted" + }, + HtmlEscape(&json_text) + ), + ) + }, + None => RoomMessageEventContent::text_plain("PDU not found."), + } + }, + DebugCommand::ForceDeviceListUpdates => { + // Force E2EE device list updates for all users + for user_id in services().users.iter().filter_map(std::result::Result::ok) { + services().users.mark_device_key_update(&user_id)?; + } + RoomMessageEventContent::text_plain("Marked all devices for all users as having new keys to update") + }, + }, + }; + + Ok(reply_message_content) + } + + fn get_room_info(id: OwnedRoomId) -> (OwnedRoomId, u64, String) { + ( + id.clone(), + services().rooms.state_cache.room_joined_count(&id).ok().flatten().unwrap_or(0), + services().rooms.state_accessor.get_name(&id).ok().flatten().unwrap_or_else(|| id.to_string()), + ) + } + + // Utility to turn clap's `--help` text to HTML. + fn usage_to_html(&self, text: &str, server_name: &ServerName) -> String { + // Replace `@conduit:servername:-subcmdname` with `@conduit:servername: + // subcmdname` + let text = text.replace(&format!("@conduit:{server_name}:-"), &format!("@conduit:{server_name}: ")); + + // For the conduit admin room, subcommands become main commands + let text = text.replace("SUBCOMMAND", "COMMAND"); + let text = text.replace("subcommand", "command"); + + // Escape option names (e.g. ``) since they look like HTML tags + let text = escape_html(&text); + + // Italicize the first line (command name and version text) + let re = Regex::new("^(.*?)\n").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1\n"); + + // Unmerge wrapped lines + let text = text.replace("\n ", " "); + + // Wrap option names in backticks. 
The lines look like: + // -V, --version Prints version information + // And are converted to: + // -V, --version: Prints version information + // (?m) enables multi-line mode for ^ and $ + let re = Regex::new("(?m)^ {4}(([a-zA-Z_&;-]+(, )?)+) +(.*)$").expect("Regex compilation should not fail"); + let text = re.replace_all(&text, "$1: $4"); + + // Look for a `[commandbody]` tag. If it exists, use all lines below it that + // start with a `#` in the USAGE section. + let mut text_lines: Vec<&str> = text.lines().collect(); + let mut command_body = String::new(); + + if let Some(line_index) = text_lines.iter().position(|line| *line == "[commandbody]") { + text_lines.remove(line_index); + + while text_lines.get(line_index).map(|line| line.starts_with('#')).unwrap_or(false) { + command_body += if text_lines[line_index].starts_with("# ") { + &text_lines[line_index][2..] + } else { + &text_lines[line_index][1..] + }; + command_body += "[nobr]\n"; + text_lines.remove(line_index); + } + } + + let text = text_lines.join("\n"); + + // Improve the usage section + let text = if command_body.is_empty() { + // Wrap the usage line in code tags + let re = Regex::new("(?m)^USAGE:\n {4}(@conduit:.*)$").expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n$1").to_string() + } else { + // Wrap the usage line in a code block, and add a yaml block example + // This makes the usage of e.g. `register-appservice` more accurate + let re = Regex::new("(?m)^USAGE:\n {4}(.*?)\n\n").expect("Regex compilation should not fail"); + re.replace_all(&text, "USAGE:\n
    $1[nobr]\n[commandbodyblock]
    ") + .replace("[commandbodyblock]", &command_body) + }; + + // Add HTML line-breaks + + text.replace("\n\n\n", "\n\n").replace('\n', "
    \n").replace("[nobr]
    ", "") + } + + /// Create the admin room. + /// + /// Users in this room are considered admins by conduit, and the room can be + /// used to issue admin commands by talking to the server user inside it. + pub(crate) async fn create_admin_room(&self) -> Result<()> { + let room_id = RoomId::new(services().globals.server_name()); + + services().rooms.short.get_or_create_shortroomid(&room_id)?; + + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; + + // Create a user for the server + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + services().users.create(&conduit_user, None)?; + + let room_version = services().globals.default_room_version(); + let mut content = match room_version { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => RoomCreateEventContent::new_v1(conduit_user.clone()), + RoomVersionId::V11 => RoomCreateEventContent::new_v11(), + _ => { + warn!("Unexpected or unsupported room version {}", room_version); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + }; + + content.federate = true; + content.predecessor = None; + content.room_version = room_version; + + // 1. The room create event + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCreate, + content: to_raw_value(&content).expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 2. 
Make conduit bot join + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(conduit_user.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 3. Power levels + let mut users = BTreeMap::new(); + users.insert(conduit_user.clone(), 100.into()); + + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 4.1 Join Rules + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomJoinRules, + content: to_raw_value(&RoomJoinRulesEventContent::new(JoinRule::Invite)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 4.2 History Visibility + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomHistoryVisibility, + content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 4.3 Guest Access + services() + .rooms + .timeline + 
.build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomGuestAccess, + content: to_raw_value(&RoomGuestAccessEventContent::new(GuestAccess::Forbidden)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 5. Events implied by name and topic + let room_name = format!("{} Admin Room", services().globals.server_name()); + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomName, + content: to_raw_value(&RoomNameEventContent::new(room_name)) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomTopic, + content: to_raw_value(&RoomTopicEventContent { + topic: format!("Manage {}", services().globals.server_name()), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // 6. Room alias + let alias: OwnedRoomAliasId = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomCanonicalAlias, + content: to_raw_value(&RoomCanonicalAliasEventContent { + alias: Some(alias.clone()), + alt_aliases: Vec::new(), + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + services().rooms.alias.set_alias(&alias, &room_id)?; + + Ok(()) + } + + /// Invite the user to the conduit admin room. 
+ /// + /// In conduit, this is equivalent to granting admin privileges. + pub(crate) async fn make_user_admin(&self, user_id: &UserId, displayname: String) -> Result<()> { + let admin_room_alias: Box = format!("#admins:{}", services().globals.server_name()) + .try_into() + .expect("#admins:server_name is a valid alias name"); + let room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias)?.expect("Admin room must exist"); + + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default()); + let state_lock = mutex_state.lock().await; + + // Use the server user to grant the new admin's power level + let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name()) + .expect("@conduit:server_name is valid"); + + // Invite and join the real user + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Invite, + displayname: None, + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomMember, + content: to_raw_value(&RoomMemberEventContent { + membership: MembershipState::Join, + displayname: Some(displayname), + avatar_url: None, + is_direct: None, + third_party_invite: None, + blurhash: None, + reason: None, + join_authorized_via_users_server: None, + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some(user_id.to_string()), + redacts: None, + }, + user_id, + &room_id, + &state_lock, + ) + .await?; + + // 
Set power level + let mut users = BTreeMap::new(); + users.insert(conduit_user.clone(), 100.into()); + users.insert(user_id.to_owned(), 100.into()); + + services() + .rooms + .timeline + .build_and_append_pdu( + PduBuilder { + event_type: TimelineEventType::RoomPowerLevels, + content: to_raw_value(&RoomPowerLevelsEventContent { + users, + ..Default::default() + }) + .expect("event is valid, we just created it"), + unsigned: None, + state_key: Some("".to_owned()), + redacts: None, + }, + &conduit_user, + &room_id, + &state_lock, + ) + .await?; + + // Send welcome message + services().rooms.timeline.build_and_append_pdu( PduBuilder { event_type: TimelineEventType::RoomMessage, content: to_raw_value(&RoomMessageEventContent::text_html( @@ -2432,43 +2396,31 @@ impl Service { &state_lock, ).await?; - Ok(()) - } + Ok(()) + } } -fn escape_html(s: &str) -> String { - s.replace('&', "&") - .replace('<', "<") - .replace('>', ">") -} +fn escape_html(s: &str) -> String { s.replace('&', "&").replace('<', "<").replace('>', ">") } #[cfg(test)] mod test { - use super::*; + use super::*; - #[test] - fn get_help_short() { - get_help_inner("-h"); - } + #[test] + fn get_help_short() { get_help_inner("-h"); } - #[test] - fn get_help_long() { - get_help_inner("--help"); - } + #[test] + fn get_help_long() { get_help_inner("--help"); } - #[test] - fn get_help_subcommand() { - get_help_inner("help"); - } + #[test] + fn get_help_subcommand() { get_help_inner("help"); } - fn get_help_inner(input: &str) { - let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]) - .unwrap_err() - .to_string(); + fn get_help_inner(input: &str) { + let error = AdminCommand::try_parse_from(["argv[0] doesn't matter", input]).unwrap_err().to_string(); - // Search for a handful of keywords that suggest the help printed properly - assert!(error.contains("Usage:")); - assert!(error.contains("Commands:")); - assert!(error.contains("Options:")); - } + // Search for a handful of keywords that suggest 
the help printed properly + assert!(error.contains("Usage:")); + assert!(error.contains("Commands:")); + assert!(error.contains("Options:")); + } } diff --git a/src/service/appservice/data.rs b/src/service/appservice/data.rs index ab19a50c..52c8b34d 100644 --- a/src/service/appservice/data.rs +++ b/src/service/appservice/data.rs @@ -3,19 +3,19 @@ use ruma::api::appservice::Registration; use crate::Result; pub trait Data: Send + Sync { - /// Registers an appservice and returns the ID to the caller - fn register_appservice(&self, yaml: Registration) -> Result; + /// Registers an appservice and returns the ID to the caller + fn register_appservice(&self, yaml: Registration) -> Result; - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - fn unregister_appservice(&self, service_name: &str) -> Result<()>; + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + fn unregister_appservice(&self, service_name: &str) -> Result<()>; - fn get_registration(&self, id: &str) -> Result>; + fn get_registration(&self, id: &str) -> Result>; - fn iter_ids<'a>(&'a self) -> Result> + 'a>>; + fn iter_ids<'a>(&'a self) -> Result> + 'a>>; - fn all(&self) -> Result>; + fn all(&self) -> Result>; } diff --git a/src/service/appservice/mod.rs b/src/service/appservice/mod.rs index ed6f37bf..5700731d 100644 --- a/src/service/appservice/mod.rs +++ b/src/service/appservice/mod.rs @@ -6,33 +6,25 @@ use ruma::api::appservice::Registration; use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Registers an appservice and returns the ID to the caller - pub fn register_appservice(&self, yaml: Registration) -> Result { - self.db.register_appservice(yaml) - } + /// Registers an appservice and returns the ID to the caller + pub fn 
register_appservice(&self, yaml: Registration) -> Result { self.db.register_appservice(yaml) } - /// Remove an appservice registration - /// - /// # Arguments - /// - /// * `service_name` - the name you send to register the service previously - pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { - self.db.unregister_appservice(service_name) - } + /// Remove an appservice registration + /// + /// # Arguments + /// + /// * `service_name` - the name you send to register the service previously + pub fn unregister_appservice(&self, service_name: &str) -> Result<()> { + self.db.unregister_appservice(service_name) + } - pub fn get_registration(&self, id: &str) -> Result> { - self.db.get_registration(id) - } + pub fn get_registration(&self, id: &str) -> Result> { self.db.get_registration(id) } - pub fn iter_ids(&self) -> Result> + '_> { - self.db.iter_ids() - } + pub fn iter_ids(&self) -> Result> + '_> { self.db.iter_ids() } - pub fn all(&self) -> Result> { - self.db.all() - } + pub fn all(&self) -> Result> { self.db.all() } } diff --git a/src/service/globals/data.rs b/src/service/globals/data.rs index 8a66751b..1ae76d21 100644 --- a/src/service/globals/data.rs +++ b/src/service/globals/data.rs @@ -2,36 +2,32 @@ use std::collections::BTreeMap; use async_trait::async_trait; use ruma::{ - api::federation::discovery::{ServerSigningKeys, VerifyKey}, - signatures::Ed25519KeyPair, - DeviceId, OwnedServerSigningKeyId, ServerName, UserId, + api::federation::discovery::{ServerSigningKeys, VerifyKey}, + signatures::Ed25519KeyPair, + DeviceId, OwnedServerSigningKeyId, ServerName, UserId, }; use crate::Result; #[async_trait] pub trait Data: Send + Sync { - fn next_count(&self) -> Result; - fn current_count(&self) -> Result; - fn last_check_for_updates_id(&self) -> Result; - fn update_check_for_updates_id(&self, id: u64) -> Result<()>; - async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - fn cleanup(&self) -> Result<()>; - fn 
memory_usage(&self) -> String; - fn clear_caches(&self, amount: u32); - fn load_keypair(&self) -> Result; - fn remove_keypair(&self) -> Result<()>; - fn add_signing_key( - &self, - origin: &ServerName, - new_keys: ServerSigningKeys, - ) -> Result>; + fn next_count(&self) -> Result; + fn current_count(&self) -> Result; + fn last_check_for_updates_id(&self) -> Result; + fn update_check_for_updates_id(&self, id: u64) -> Result<()>; + async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + fn cleanup(&self) -> Result<()>; + fn memory_usage(&self) -> String; + fn clear_caches(&self, amount: u32); + fn load_keypair(&self) -> Result; + fn remove_keypair(&self) -> Result<()>; + fn add_signing_key( + &self, origin: &ServerName, new_keys: ServerSigningKeys, + ) -> Result>; - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. - fn signing_keys_for( - &self, - origin: &ServerName, - ) -> Result>; - fn database_version(&self) -> Result; - fn bump_database_version(&self, new_version: u64) -> Result<()>; + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. 
+ fn signing_keys_for(&self, origin: &ServerName) -> Result>; + fn database_version(&self) -> Result; + fn bump_database_version(&self, new_version: u64) -> Result<()>; } diff --git a/src/service/globals/mod.rs b/src/service/globals/mod.rs index 26898fd1..1f88cbc3 100644 --- a/src/service/globals/mod.rs +++ b/src/service/globals/mod.rs @@ -1,47 +1,43 @@ use std::{ - collections::{BTreeMap, HashMap}, - error::Error as StdError, - fs, - future::{self, Future}, - iter, - net::{IpAddr, SocketAddr}, - path::PathBuf, - sync::{ - atomic::{self, AtomicBool}, - Arc, Mutex, RwLock, - }, - time::{Duration, Instant}, + collections::{BTreeMap, HashMap}, + error::Error as StdError, + fs, + future::{self, Future}, + iter, + net::{IpAddr, SocketAddr}, + path::PathBuf, + sync::{ + atomic::{self, AtomicBool}, + Arc, Mutex, RwLock, + }, + time::{Duration, Instant}, }; use argon2::Argon2; use base64::{engine::general_purpose, Engine as _}; +pub use data::Data; use futures_util::FutureExt; use hyper::{ - client::connect::dns::{GaiResolver, Name}, - service::Service as HyperService, + client::connect::dns::{GaiResolver, Name}, + service::Service as HyperService, }; use regex::RegexSet; use reqwest::dns::{Addrs, Resolve, Resolving}; use ruma::{ - api::{ - client::sync::sync_events, - federation::discovery::{ServerSigningKeys, VerifyKey}, - }, - DeviceId, RoomVersionId, ServerName, UserId, -}; -use ruma::{ - serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, - OwnedServerSigningKeyId, OwnedUserId, + api::{ + client::sync::sync_events, + federation::discovery::{ServerSigningKeys, VerifyKey}, + }, + serde::Base64, + DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId, + RoomVersionId, ServerName, UserId, }; use sha2::Digest; use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore}; use tracing::{error, info}; use trust_dns_resolver::TokioAsyncResolver; -pub use data::Data; - -use 
crate::api::server_server::FedDest; -use crate::{services, Config, Error, Result}; +use crate::{api::server_server::FedDest, services, Config, Error, Result}; mod data; @@ -49,608 +45,488 @@ type WellKnownMap = HashMap; type TlsNameMap = HashMap, u16)>; type RateLimitState = (Instant, u32); // Time if last failed try, number of failed tries type SyncHandle = ( - Option, // since - Receiver>>, // rx + Option, // since + Receiver>>, // rx ); pub struct Service<'a> { - pub db: &'static dyn Data, + pub db: &'static dyn Data, - pub actual_destination_cache: Arc>, // actual_destination, host - pub tls_name_override: Arc>, - pub config: Config, - keypair: Arc, - dns_resolver: TokioAsyncResolver, - jwt_decoding_key: Option, - url_preview_client: reqwest::Client, - federation_client: reqwest::Client, - default_client: reqwest::Client, - pub stable_room_versions: Vec, - pub unstable_room_versions: Vec, - pub bad_event_ratelimiter: Arc>>, - pub bad_signature_ratelimiter: Arc, RateLimitState>>>, - pub bad_query_ratelimiter: Arc>>, - pub servername_ratelimiter: Arc>>>, - pub sync_receivers: RwLock>, - pub roomid_mutex_insert: RwLock>>>, - pub roomid_mutex_state: RwLock>>>, - pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer - pub roomid_federationhandletime: RwLock>, - pub stateres_mutex: Arc>, - pub(crate) rotate: RotationHandler, + pub actual_destination_cache: Arc>, // actual_destination, host + pub tls_name_override: Arc>, + pub config: Config, + keypair: Arc, + dns_resolver: TokioAsyncResolver, + jwt_decoding_key: Option, + url_preview_client: reqwest::Client, + federation_client: reqwest::Client, + default_client: reqwest::Client, + pub stable_room_versions: Vec, + pub unstable_room_versions: Vec, + pub bad_event_ratelimiter: Arc>>, + pub bad_signature_ratelimiter: Arc, RateLimitState>>>, + pub bad_query_ratelimiter: Arc>>, + pub servername_ratelimiter: Arc>>>, + pub sync_receivers: RwLock>, + pub roomid_mutex_insert: RwLock>>>, + pub 
roomid_mutex_state: RwLock>>>, + pub roomid_mutex_federation: RwLock>>>, // this lock will be held longer + pub roomid_federationhandletime: RwLock>, + pub stateres_mutex: Arc>, + pub(crate) rotate: RotationHandler, - pub shutdown: AtomicBool, - pub argon: Argon2<'a>, + pub shutdown: AtomicBool, + pub argon: Argon2<'a>, } -/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like. +/// Handles "rotation" of long-polling requests. "Rotation" in this context is +/// similar to "rotation" of log files and the like. /// -/// This is utilized to have sync workers return early and release read locks on the database. +/// This is utilized to have sync workers return early and release read locks on +/// the database. pub(crate) struct RotationHandler(broadcast::Sender<()>, ()); impl RotationHandler { - pub fn new() -> Self { - let (s, _r) = broadcast::channel(1); - Self(s, ()) - } + pub fn new() -> Self { + let (s, _r) = broadcast::channel(1); + Self(s, ()) + } - pub fn watch(&self) -> impl Future { - let mut r = self.0.subscribe(); + pub fn watch(&self) -> impl Future { + let mut r = self.0.subscribe(); - async move { - let _ = r.recv().await; - } - } + async move { + let _ = r.recv().await; + } + } - pub fn fire(&self) { - let _ = self.0.send(()); - } + pub fn fire(&self) { let _ = self.0.send(()); } } impl Default for RotationHandler { - fn default() -> Self { - Self::new() - } + fn default() -> Self { Self::new() } } struct Resolver { - inner: GaiResolver, - overrides: Arc>, + inner: GaiResolver, + overrides: Arc>, } impl Resolver { - fn new(overrides: Arc>) -> Self { - Resolver { - inner: GaiResolver::new(), - overrides, - } - } + fn new(overrides: Arc>) -> Self { + Resolver { + inner: GaiResolver::new(), + overrides, + } + } } impl Resolve for Resolver { - fn resolve(&self, name: Name) -> Resolving { - self.overrides - .read() - .expect("lock should not be poisoned") - .get(name.as_str()) - 
.and_then(|(override_name, port)| { - override_name.first().map(|first_name| { - let x: Box + Send> = - Box::new(iter::once(SocketAddr::new(*first_name, *port))); - let x: Resolving = Box::pin(future::ready(Ok(x))); - x - }) - }) - .unwrap_or_else(|| { - let this = &mut self.inner.clone(); - Box::pin(HyperService::::call(this, name).map(|result| { - result - .map(|addrs| -> Addrs { Box::new(addrs) }) - .map_err(|err| -> Box { Box::new(err) }) - })) - }) - } + fn resolve(&self, name: Name) -> Resolving { + self.overrides + .read() + .expect("lock should not be poisoned") + .get(name.as_str()) + .and_then(|(override_name, port)| { + override_name.first().map(|first_name| { + let x: Box + Send> = + Box::new(iter::once(SocketAddr::new(*first_name, *port))); + let x: Resolving = Box::pin(future::ready(Ok(x))); + x + }) + }) + .unwrap_or_else(|| { + let this = &mut self.inner.clone(); + Box::pin(HyperService::::call(this, name).map(|result| { + result + .map(|addrs| -> Addrs { Box::new(addrs) }) + .map_err(|err| -> Box { Box::new(err) }) + })) + }) + } } impl Service<'_> { - pub fn load(db: &'static dyn Data, config: Config) -> Result { - let keypair = db.load_keypair(); - - let keypair = match keypair { - Ok(k) => k, - Err(e) => { - error!("Keypair invalid. Deleting..."); - db.remove_keypair()?; - return Err(e); - } - }; - - let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); - - let jwt_decoding_key = config - .jwt_secret - .as_ref() - .map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); - - let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?; - let default_client = reqwest_client_builder(&config)?.build()?; - let federation_client = reqwest_client_builder(&config)? 
- .dns_resolver(Arc::new(Resolver::new(tls_name_override.clone()))) - .build()?; - - // Supported and stable room versions - let stable_room_versions = vec![ - RoomVersionId::V6, - RoomVersionId::V7, - RoomVersionId::V8, - RoomVersionId::V9, - RoomVersionId::V10, - ]; - // Experimental, partially supported room versions - let unstable_room_versions = vec![ - RoomVersionId::V2, - RoomVersionId::V3, - RoomVersionId::V4, - RoomVersionId::V5, - RoomVersionId::V11, - ]; - // 19456 Kib blocks, iterations = 2, parallelism = 1 for more info https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id - let argon = Argon2::new( - argon2::Algorithm::Argon2id, - argon2::Version::default(), - argon2::Params::new(19456, 2, 1, None).expect("valid parameters"), - ); - let mut s = Self { - db, - config, - keypair: Arc::new(keypair), - dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { - error!( - "Failed to set up trust dns resolver with system config: {}", - e - ); - Error::bad_config("Failed to set up trust dns resolver with system config.") - })?, - actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), - tls_name_override, - url_preview_client, - federation_client, - default_client, - jwt_decoding_key, - stable_room_versions, - unstable_room_versions, - bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())), - roomid_mutex_state: RwLock::new(HashMap::new()), - roomid_mutex_insert: RwLock::new(HashMap::new()), - roomid_mutex_federation: RwLock::new(HashMap::new()), - roomid_federationhandletime: RwLock::new(HashMap::new()), - stateres_mutex: Arc::new(Mutex::new(())), - sync_receivers: RwLock::new(HashMap::new()), - rotate: RotationHandler::new(), - shutdown: AtomicBool::new(false), - argon, - }; - - 
fs::create_dir_all(s.get_media_folder())?; - - if !s - .supported_room_versions() - .contains(&s.config.default_room_version) - { - error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); - s.config.default_room_version = crate::config::default_default_room_version(); - }; - - Ok(s) - } - - /// Returns this server's keypair. - pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { - &self.keypair - } - - /// Returns a reqwest client which can be used to send requests for URL previews - /// This is the same as `default_client()` except a redirect policy of max 2 is set - pub fn url_preview_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.url_preview_client.clone() - } - - /// Returns a reqwest client which can be used to send requests - pub fn default_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.default_client.clone() - } - - /// Returns a client used for resolving .well-knowns - pub fn federation_client(&self) -> reqwest::Client { - // Client is cheap to clone (Arc wrapper) and avoids lifetime issues - self.federation_client.clone() - } - - #[tracing::instrument(skip(self))] - pub fn next_count(&self) -> Result { - self.db.next_count() - } - - #[tracing::instrument(skip(self))] - pub fn current_count(&self) -> Result { - self.db.current_count() - } - - #[tracing::instrument(skip(self))] - pub fn last_check_for_updates_id(&self) -> Result { - self.db.last_check_for_updates_id() - } - - #[tracing::instrument(skip(self))] - pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { - self.db.update_check_for_updates_id(id) - } - - pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { - self.db.watch(user_id, device_id).await - } - - pub fn cleanup(&self) -> Result<()> { - 
self.db.cleanup() - } - - pub fn server_name(&self) -> &ServerName { - self.config.server_name.as_ref() - } - - pub fn max_request_size(&self) -> u32 { - self.config.max_request_size - } - - pub fn max_fetch_prev_events(&self) -> u16 { - self.config.max_fetch_prev_events - } - - pub fn allow_registration(&self) -> bool { - self.config.allow_registration - } - - pub fn allow_guest_registration(&self) -> bool { - self.config.allow_guest_registration - } - - pub fn allow_encryption(&self) -> bool { - self.config.allow_encryption - } - - pub fn allow_federation(&self) -> bool { - self.config.allow_federation - } - - pub fn allow_public_room_directory_over_federation(&self) -> bool { - self.config.allow_public_room_directory_over_federation - } - - pub fn allow_public_room_directory_without_auth(&self) -> bool { - self.config.allow_public_room_directory_without_auth - } - - pub fn allow_device_name_federation(&self) -> bool { - self.config.allow_device_name_federation - } - - pub fn allow_room_creation(&self) -> bool { - self.config.allow_room_creation - } - - pub fn allow_unstable_room_versions(&self) -> bool { - self.config.allow_unstable_room_versions - } - - pub fn default_room_version(&self) -> RoomVersionId { - self.config.default_room_version.clone() - } - - pub fn new_user_displayname_suffix(&self) -> &String { - &self.config.new_user_displayname_suffix - } - - pub fn allow_check_for_updates(&self) -> bool { - self.config.allow_check_for_updates - } - - pub fn trusted_servers(&self) -> &[OwnedServerName] { - &self.config.trusted_servers - } - - pub fn query_trusted_key_servers_first(&self) -> bool { - self.config.query_trusted_key_servers_first - } - - pub fn dns_resolver(&self) -> &TokioAsyncResolver { - &self.dns_resolver - } + pub fn load(db: &'static dyn Data, config: Config) -> Result { + let keypair = db.load_keypair(); + + let keypair = match keypair { + Ok(k) => k, + Err(e) => { + error!("Keypair invalid. 
Deleting..."); + db.remove_keypair()?; + return Err(e); + }, + }; + + let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new())); + + let jwt_decoding_key = + config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes())); + + let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?; + let default_client = reqwest_client_builder(&config)?.build()?; + let federation_client = reqwest_client_builder(&config)? + .dns_resolver(Arc::new(Resolver::new(tls_name_override.clone()))) + .build()?; + + // Supported and stable room versions + let stable_room_versions = vec![ + RoomVersionId::V6, + RoomVersionId::V7, + RoomVersionId::V8, + RoomVersionId::V9, + RoomVersionId::V10, + ]; + // Experimental, partially supported room versions + let unstable_room_versions = vec![ + RoomVersionId::V2, + RoomVersionId::V3, + RoomVersionId::V4, + RoomVersionId::V5, + RoomVersionId::V11, + ]; + // 19456 Kib blocks, iterations = 2, parallelism = 1 for more info https://cheatsheetseries.owasp.org/cheatsheets/Password_Storage_Cheat_Sheet.html#argon2id + let argon = Argon2::new( + argon2::Algorithm::Argon2id, + argon2::Version::default(), + argon2::Params::new(19456, 2, 1, None).expect("valid parameters"), + ); + let mut s = Self { + db, + config, + keypair: Arc::new(keypair), + dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| { + error!("Failed to set up trust dns resolver with system config: {}", e); + Error::bad_config("Failed to set up trust dns resolver with system config.") + })?, + actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())), + tls_name_override, + url_preview_client, + federation_client, + default_client, + jwt_decoding_key, + stable_room_versions, + unstable_room_versions, + bad_event_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + bad_signature_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + bad_query_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + 
servername_ratelimiter: Arc::new(RwLock::new(HashMap::new())), + roomid_mutex_state: RwLock::new(HashMap::new()), + roomid_mutex_insert: RwLock::new(HashMap::new()), + roomid_mutex_federation: RwLock::new(HashMap::new()), + roomid_federationhandletime: RwLock::new(HashMap::new()), + stateres_mutex: Arc::new(Mutex::new(())), + sync_receivers: RwLock::new(HashMap::new()), + rotate: RotationHandler::new(), + shutdown: AtomicBool::new(false), + argon, + }; + + fs::create_dir_all(s.get_media_folder())?; + + if !s.supported_room_versions().contains(&s.config.default_room_version) { + error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version"); + s.config.default_room_version = crate::config::default_default_room_version(); + }; - pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { - self.jwt_decoding_key.as_ref() - } + Ok(s) + } + + /// Returns this server's keypair. 
+ pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair } + + /// Returns a reqwest client which can be used to send requests for URL + /// previews This is the same as `default_client()` except a redirect policy + /// of max 2 is set + pub fn url_preview_client(&self) -> reqwest::Client { + // Client is cheap to clone (Arc wrapper) and avoids lifetime issues + self.url_preview_client.clone() + } - pub fn turn_password(&self) -> &String { - &self.config.turn_password - } + /// Returns a reqwest client which can be used to send requests + pub fn default_client(&self) -> reqwest::Client { + // Client is cheap to clone (Arc wrapper) and avoids lifetime issues + self.default_client.clone() + } - pub fn turn_ttl(&self) -> u64 { - self.config.turn_ttl - } + /// Returns a client used for resolving .well-knowns + pub fn federation_client(&self) -> reqwest::Client { + // Client is cheap to clone (Arc wrapper) and avoids lifetime issues + self.federation_client.clone() + } - pub fn turn_uris(&self) -> &[String] { - &self.config.turn_uris - } - - pub fn turn_username(&self) -> &String { - &self.config.turn_username - } - - pub fn turn_secret(&self) -> &String { - &self.config.turn_secret - } - - pub fn notification_push_path(&self) -> &String { - &self.config.notification_push_path - } - - pub fn emergency_password(&self) -> &Option { - &self.config.emergency_password - } - - pub fn url_preview_domain_contains_allowlist(&self) -> &Vec { - &self.config.url_preview_domain_contains_allowlist - } - - pub fn url_preview_domain_explicit_allowlist(&self) -> &Vec { - &self.config.url_preview_domain_explicit_allowlist - } - - pub fn url_preview_url_contains_allowlist(&self) -> &Vec { - &self.config.url_preview_url_contains_allowlist - } - - pub fn url_preview_max_spider_size(&self) -> usize { - self.config.url_preview_max_spider_size - } - - pub fn url_preview_check_root_domain(&self) -> bool { - self.config.url_preview_check_root_domain - } - - pub fn 
forbidden_room_names(&self) -> &RegexSet { - &self.config.forbidden_room_names - } - - pub fn forbidden_usernames(&self) -> &RegexSet { - &self.config.forbidden_usernames - } - - pub fn allow_local_presence(&self) -> bool { - self.config.allow_local_presence - } - - pub fn allow_incoming_presence(&self) -> bool { - self.config.allow_incoming_presence - } - - pub fn allow_outgoing_presence(&self) -> bool { - self.config.allow_outgoing_presence - } - - pub fn presence_idle_timeout_s(&self) -> u64 { - self.config.presence_idle_timeout_s - } - - pub fn presence_offline_timeout_s(&self) -> u64 { - self.config.presence_offline_timeout_s - } - - pub fn rocksdb_log_level(&self) -> &String { - &self.config.rocksdb_log_level - } - - pub fn rocksdb_max_log_file_size(&self) -> usize { - self.config.rocksdb_max_log_file_size - } - - pub fn rocksdb_log_time_to_roll(&self) -> usize { - self.config.rocksdb_log_time_to_roll - } - - pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool { - self.config.rocksdb_optimize_for_spinning_disks - } - - pub fn rocksdb_parallelism_threads(&self) -> usize { - self.config.rocksdb_parallelism_threads - } - - pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { - &self.config.prevent_media_downloads_from - } - - pub fn ip_range_denylist(&self) -> &[String] { - &self.config.ip_range_denylist - } - - pub fn block_non_admin_invites(&self) -> bool { - self.config.block_non_admin_invites - } - - pub fn supported_room_versions(&self) -> Vec { - let mut room_versions: Vec = vec![]; - room_versions.extend(self.stable_room_versions.clone()); - if self.allow_unstable_room_versions() { - room_versions.extend(self.unstable_room_versions.clone()); - }; - room_versions - } - - /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored in room version > 4 - /// - /// Remove the outdated keys and insert the new ones. - /// - /// This doesn't actually check that the keys provided are newer than the old set. 
- pub fn add_signing_key( - &self, - origin: &ServerName, - new_keys: ServerSigningKeys, - ) -> Result> { - self.db.add_signing_key(origin, new_keys) - } - - /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server. - pub fn signing_keys_for( - &self, - origin: &ServerName, - ) -> Result> { - let mut keys = self.db.signing_keys_for(origin)?; - if origin == self.server_name() { - keys.insert( - format!("ed25519:{}", services().globals.keypair().version()) - .try_into() - .expect("found invalid server signing keys in DB"), - VerifyKey { - key: Base64::new(self.keypair.public_key().to_vec()), - }, - ); - } - - Ok(keys) - } - - pub fn database_version(&self) -> Result { - self.db.database_version() - } - - pub fn bump_database_version(&self, new_version: u64) -> Result<()> { - self.db.bump_database_version(new_version) - } - - pub fn get_media_folder(&self) -> PathBuf { - let mut r = PathBuf::new(); - r.push(self.config.database_path.clone()); - r.push("media"); - r - } - - /// new SHA256 file name media function, requires "sha256_media" feature flag enabled and database migrated - /// uses SHA256 hash of the base64 key as the file name - pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf { - let mut r = PathBuf::new(); - r.push(self.config.database_path.clone()); - r.push("media"); - // Using the hash of the base64 key as the filename - // This is to prevent the total length of the path from exceeding the maximum length in most filesystems - r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key))); - r - } - - /// old base64 file name media function - /// This is the old version of `get_media_file` that uses the full base64 key as the filename. - /// - /// This is deprecated and will be removed in a future release. - /// Please use `get_media_file_new` instead. 
- #[deprecated(note = "Use get_media_file_new instead")] - pub fn get_media_file(&self, key: &[u8]) -> PathBuf { - let mut r = PathBuf::new(); - r.push(self.config.database_path.clone()); - r.push("media"); - r.push(general_purpose::URL_SAFE_NO_PAD.encode(key)); - r - } - - pub fn well_known_client(&self) -> &Option { - &self.config.well_known_client - } - - pub fn well_known_server(&self) -> &Option { - &self.config.well_known_server - } - - pub fn unix_socket_path(&self) -> &Option { - &self.config.unix_socket_path - } - - pub fn shutdown(&self) { - self.shutdown.store(true, atomic::Ordering::Relaxed); - // On shutdown - - if self.unix_socket_path().is_some() { - match &self.unix_socket_path() { - Some(path) => { - std::fs::remove_file(path).unwrap(); - } - None => error!( - "Unable to remove socket file at {:?} during shutdown.", - &self.unix_socket_path() - ), - }; - }; - - info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); - services().globals.rotate.fire(); - } + #[tracing::instrument(skip(self))] + pub fn next_count(&self) -> Result { self.db.next_count() } + + #[tracing::instrument(skip(self))] + pub fn current_count(&self) -> Result { self.db.current_count() } + + #[tracing::instrument(skip(self))] + pub fn last_check_for_updates_id(&self) -> Result { self.db.last_check_for_updates_id() } + + #[tracing::instrument(skip(self))] + pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) } + + pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.watch(user_id, device_id).await + } + + pub fn cleanup(&self) -> Result<()> { self.db.cleanup() } + + pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() } + + pub fn max_request_size(&self) -> u32 { self.config.max_request_size } + + pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events } + + pub fn allow_registration(&self) -> bool { 
self.config.allow_registration } + + pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration } + + pub fn allow_encryption(&self) -> bool { self.config.allow_encryption } + + pub fn allow_federation(&self) -> bool { self.config.allow_federation } + + pub fn allow_public_room_directory_over_federation(&self) -> bool { + self.config.allow_public_room_directory_over_federation + } + + pub fn allow_public_room_directory_without_auth(&self) -> bool { + self.config.allow_public_room_directory_without_auth + } + + pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation } + + pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation } + + pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions } + + pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() } + + pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix } + + pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates } + + pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers } + + pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first } + + pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.dns_resolver } + + pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() } + + pub fn turn_password(&self) -> &String { &self.config.turn_password } + + pub fn turn_ttl(&self) -> u64 { self.config.turn_ttl } + + pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris } + + pub fn turn_username(&self) -> &String { &self.config.turn_username } + + pub fn turn_secret(&self) -> &String { &self.config.turn_secret } + + pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path } + + pub fn emergency_password(&self) -> &Option { 
&self.config.emergency_password } + + pub fn url_preview_domain_contains_allowlist(&self) -> &Vec { + &self.config.url_preview_domain_contains_allowlist + } + + pub fn url_preview_domain_explicit_allowlist(&self) -> &Vec { + &self.config.url_preview_domain_explicit_allowlist + } + + pub fn url_preview_url_contains_allowlist(&self) -> &Vec { &self.config.url_preview_url_contains_allowlist } + + pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size } + + pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain } + + pub fn forbidden_room_names(&self) -> &RegexSet { &self.config.forbidden_room_names } + + pub fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames } + + pub fn allow_local_presence(&self) -> bool { self.config.allow_local_presence } + + pub fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence } + + pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence } + + pub fn presence_idle_timeout_s(&self) -> u64 { self.config.presence_idle_timeout_s } + + pub fn presence_offline_timeout_s(&self) -> u64 { self.config.presence_offline_timeout_s } + + pub fn rocksdb_log_level(&self) -> &String { &self.config.rocksdb_log_level } + + pub fn rocksdb_max_log_file_size(&self) -> usize { self.config.rocksdb_max_log_file_size } + + pub fn rocksdb_log_time_to_roll(&self) -> usize { self.config.rocksdb_log_time_to_roll } + + pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool { self.config.rocksdb_optimize_for_spinning_disks } + + pub fn rocksdb_parallelism_threads(&self) -> usize { self.config.rocksdb_parallelism_threads } + + pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from } + + pub fn ip_range_denylist(&self) -> &[String] { &self.config.ip_range_denylist } + + pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites } + + pub 
fn supported_room_versions(&self) -> Vec { + let mut room_versions: Vec = vec![]; + room_versions.extend(self.stable_room_versions.clone()); + if self.allow_unstable_room_versions() { + room_versions.extend(self.unstable_room_versions.clone()); + }; + room_versions + } + + /// TODO: the key valid until timestamp (`valid_until_ts`) is only honored + /// in room version > 4 + /// + /// Remove the outdated keys and insert the new ones. + /// + /// This doesn't actually check that the keys provided are newer than the + /// old set. + pub fn add_signing_key( + &self, origin: &ServerName, new_keys: ServerSigningKeys, + ) -> Result> { + self.db.add_signing_key(origin, new_keys) + } + + /// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found + /// for the server. + pub fn signing_keys_for(&self, origin: &ServerName) -> Result> { + let mut keys = self.db.signing_keys_for(origin)?; + if origin == self.server_name() { + keys.insert( + format!("ed25519:{}", services().globals.keypair().version()) + .try_into() + .expect("found invalid server signing keys in DB"), + VerifyKey { + key: Base64::new(self.keypair.public_key().to_vec()), + }, + ); + } + + Ok(keys) + } + + pub fn database_version(&self) -> Result { self.db.database_version() } + + pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) } + + pub fn get_media_folder(&self) -> PathBuf { + let mut r = PathBuf::new(); + r.push(self.config.database_path.clone()); + r.push("media"); + r + } + + /// new SHA256 file name media function, requires "sha256_media" feature + /// flag enabled and database migrated uses SHA256 hash of the base64 key as + /// the file name + pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf { + let mut r = PathBuf::new(); + r.push(self.config.database_path.clone()); + r.push("media"); + // Using the hash of the base64 key as the filename + // This is to prevent the total length of the path from exceeding the maximum + 
// length in most filesystems + r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key))); + r + } + + /// old base64 file name media function + /// This is the old version of `get_media_file` that uses the full base64 + /// key as the filename. + /// + /// This is deprecated and will be removed in a future release. + /// Please use `get_media_file_new` instead. + #[deprecated(note = "Use get_media_file_new instead")] + pub fn get_media_file(&self, key: &[u8]) -> PathBuf { + let mut r = PathBuf::new(); + r.push(self.config.database_path.clone()); + r.push("media"); + r.push(general_purpose::URL_SAFE_NO_PAD.encode(key)); + r + } + + pub fn well_known_client(&self) -> &Option { &self.config.well_known_client } + + pub fn well_known_server(&self) -> &Option { &self.config.well_known_server } + + pub fn unix_socket_path(&self) -> &Option { &self.config.unix_socket_path } + + pub fn shutdown(&self) { + self.shutdown.store(true, atomic::Ordering::Relaxed); + // On shutdown + + if self.unix_socket_path().is_some() { + match &self.unix_socket_path() { + Some(path) => { + std::fs::remove_file(path).unwrap(); + }, + None => error!( + "Unable to remove socket file at {:?} during shutdown.", + &self.unix_socket_path() + ), + }; + }; + + info!(target: "shutdown-sync", "Received shutdown notification, notifying sync helpers..."); + services().globals.rotate.fire(); + } } fn reqwest_client_builder(config: &Config) -> Result { - let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { - if attempt.previous().len() > 6 { - attempt.error("Too many redirects (max is 6)") - } else { - attempt.follow() - } - }); + let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { + if attempt.previous().len() > 6 { + attempt.error("Too many redirects (max is 6)") + } else { + attempt.follow() + } + }); - let mut reqwest_client_builder = reqwest::Client::builder() - .pool_max_idle_per_host(0) - .connect_timeout(Duration::from_secs(60)) - 
.timeout(Duration::from_secs(60 * 5)) - .redirect(redirect_policy) - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); + let mut reqwest_client_builder = reqwest::Client::builder() + .pool_max_idle_per_host(0) + .connect_timeout(Duration::from_secs(60)) + .timeout(Duration::from_secs(60 * 5)) + .redirect(redirect_policy) + .user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))); - if let Some(proxy) = config.proxy.to_proxy()? { - reqwest_client_builder = reqwest_client_builder.proxy(proxy); - } + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } - Ok(reqwest_client_builder) + Ok(reqwest_client_builder) } fn url_preview_reqwest_client_builder(config: &Config) -> Result { - // for security reasons (e.g. malicious open redirect), we do not want to follow too many redirects when generating URL previews. - // let's keep it at least 2 to account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider raising it to 3. - let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { - if attempt.previous().len() > 2 { - attempt.error("Too many redirects (max is 2)") - } else { - attempt.follow() - } - }); + // for security reasons (e.g. malicious open redirect), we do not want to follow + // too many redirects when generating URL previews. let's keep it at least 2 to + // account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider + // raising it to 3. 
+ let redirect_policy = reqwest::redirect::Policy::custom(|attempt| { + if attempt.previous().len() > 2 { + attempt.error("Too many redirects (max is 2)") + } else { + attempt.follow() + } + }); - let mut reqwest_client_builder = reqwest::Client::builder() - .pool_max_idle_per_host(0) - .connect_timeout(Duration::from_secs(60)) - .timeout(Duration::from_secs(60 * 5)) - .redirect(redirect_policy) - .user_agent(concat!( - env!("CARGO_PKG_NAME"), - "/", - env!("CARGO_PKG_VERSION") - )); + let mut reqwest_client_builder = reqwest::Client::builder() + .pool_max_idle_per_host(0) + .connect_timeout(Duration::from_secs(60)) + .timeout(Duration::from_secs(60 * 5)) + .redirect(redirect_policy) + .user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"))); - if let Some(proxy) = config.proxy.to_proxy()? { - reqwest_client_builder = reqwest_client_builder.proxy(proxy); - } + if let Some(proxy) = config.proxy.to_proxy()? { + reqwest_client_builder = reqwest_client_builder.proxy(proxy); + } - Ok(reqwest_client_builder) + Ok(reqwest_client_builder) } diff --git a/src/service/key_backups/data.rs b/src/service/key_backups/data.rs index bf640015..ac595a6b 100644 --- a/src/service/key_backups/data.rs +++ b/src/service/key_backups/data.rs @@ -1,78 +1,47 @@ use std::collections::BTreeMap; -use crate::Result; use ruma::{ - api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - serde::Raw, - OwnedRoomId, RoomId, UserId, + api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + serde::Raw, + OwnedRoomId, RoomId, UserId, }; +use crate::Result; + pub trait Data: Send + Sync { - fn create_backup( - &self, - user_id: &UserId, - backup_metadata: &Raw, - ) -> Result; + fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result; - fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; + fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>; - fn update_backup( - &self, - user_id: &UserId, - 
version: &str, - backup_metadata: &Raw, - ) -> Result; + fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw) -> Result; - fn get_latest_backup_version(&self, user_id: &UserId) -> Result>; + fn get_latest_backup_version(&self, user_id: &UserId) -> Result>; - fn get_latest_backup(&self, user_id: &UserId) - -> Result)>>; + fn get_latest_backup(&self, user_id: &UserId) -> Result)>>; - fn get_backup(&self, user_id: &UserId, version: &str) -> Result>>; + fn get_backup(&self, user_id: &UserId, version: &str) -> Result>>; - fn add_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - key_data: &Raw, - ) -> Result<()>; + fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, + ) -> Result<()>; - fn count_keys(&self, user_id: &UserId, version: &str) -> Result; + fn count_keys(&self, user_id: &UserId, version: &str) -> Result; - fn get_etag(&self, user_id: &UserId, version: &str) -> Result; + fn get_etag(&self, user_id: &UserId, version: &str) -> Result; - fn get_all( - &self, - user_id: &UserId, - version: &str, - ) -> Result>; + fn get_all(&self, user_id: &UserId, version: &str) -> Result>; - fn get_room( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - ) -> Result>>; + fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, + ) -> Result>>; - fn get_session( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - ) -> Result>>; + fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, + ) -> Result>>; - fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; + fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>; - fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; + fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>; - fn 
delete_room_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - ) -> Result<()>; + fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>; } diff --git a/src/service/key_backups/mod.rs b/src/service/key_backups/mod.rs index 50eca0b6..202f72d2 100644 --- a/src/service/key_backups/mod.rs +++ b/src/service/key_backups/mod.rs @@ -1,127 +1,81 @@ mod data; -pub(crate) use data::Data; - -use crate::Result; -use ruma::{ - api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, - serde::Raw, - OwnedRoomId, RoomId, UserId, -}; use std::collections::BTreeMap; +pub(crate) use data::Data; +use ruma::{ + api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup}, + serde::Raw, + OwnedRoomId, RoomId, UserId, +}; + +use crate::Result; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn create_backup( - &self, - user_id: &UserId, - backup_metadata: &Raw, - ) -> Result { - self.db.create_backup(user_id, backup_metadata) - } + pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw) -> Result { + self.db.create_backup(user_id, backup_metadata) + } - pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_backup(user_id, version) - } + pub fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> { + self.db.delete_backup(user_id, version) + } - pub fn update_backup( - &self, - user_id: &UserId, - version: &str, - backup_metadata: &Raw, - ) -> Result { - self.db.update_backup(user_id, version, backup_metadata) - } + pub fn update_backup( + &self, user_id: &UserId, version: &str, backup_metadata: &Raw, + ) -> Result { + self.db.update_backup(user_id, version, backup_metadata) + } - pub fn get_latest_backup_version(&self, user_id: &UserId) -> Result> { - self.db.get_latest_backup_version(user_id) - } + pub fn get_latest_backup_version(&self, user_id: 
&UserId) -> Result> { + self.db.get_latest_backup_version(user_id) + } - pub fn get_latest_backup( - &self, - user_id: &UserId, - ) -> Result)>> { - self.db.get_latest_backup(user_id) - } + pub fn get_latest_backup(&self, user_id: &UserId) -> Result)>> { + self.db.get_latest_backup(user_id) + } - pub fn get_backup( - &self, - user_id: &UserId, - version: &str, - ) -> Result>> { - self.db.get_backup(user_id, version) - } + pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result>> { + self.db.get_backup(user_id, version) + } - pub fn add_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - key_data: &Raw, - ) -> Result<()> { - self.db - .add_key(user_id, version, room_id, session_id, key_data) - } + pub fn add_key( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw, + ) -> Result<()> { + self.db.add_key(user_id, version, room_id, session_id, key_data) + } - pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { - self.db.count_keys(user_id, version) - } + pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result { self.db.count_keys(user_id, version) } - pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { - self.db.get_etag(user_id, version) - } + pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result { self.db.get_etag(user_id, version) } - pub fn get_all( - &self, - user_id: &UserId, - version: &str, - ) -> Result> { - self.db.get_all(user_id, version) - } + pub fn get_all(&self, user_id: &UserId, version: &str) -> Result> { + self.db.get_all(user_id, version) + } - pub fn get_room( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - ) -> Result>> { - self.db.get_room(user_id, version, room_id) - } + pub fn get_room( + &self, user_id: &UserId, version: &str, room_id: &RoomId, + ) -> Result>> { + self.db.get_room(user_id, version, room_id) + } - pub fn get_session( - &self, - user_id: &UserId, - version: 
&str, - room_id: &RoomId, - session_id: &str, - ) -> Result>> { - self.db.get_session(user_id, version, room_id, session_id) - } + pub fn get_session( + &self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, + ) -> Result>> { + self.db.get_session(user_id, version, room_id, session_id) + } - pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { - self.db.delete_all_keys(user_id, version) - } + pub fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> { + self.db.delete_all_keys(user_id, version) + } - pub fn delete_room_keys( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - ) -> Result<()> { - self.db.delete_room_keys(user_id, version, room_id) - } + pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> { + self.db.delete_room_keys(user_id, version, room_id) + } - pub fn delete_room_key( - &self, - user_id: &UserId, - version: &str, - room_id: &RoomId, - session_id: &str, - ) -> Result<()> { - self.db - .delete_room_key(user_id, version, room_id, session_id) - } + pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> { + self.db.delete_room_key(user_id, version, room_id, session_id) + } } diff --git a/src/service/media/data.rs b/src/service/media/data.rs index bb44de80..9da50860 100644 --- a/src/service/media/data.rs +++ b/src/service/media/data.rs @@ -1,37 +1,24 @@ use crate::Result; pub trait Data: Send + Sync { - fn create_file_metadata( - &self, - mxc: String, - width: u32, - height: u32, - content_disposition: Option<&str>, - content_type: Option<&str>, - ) -> Result>; + fn create_file_metadata( + &self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>, + ) -> Result>; - fn delete_file_mxc(&self, mxc: String) -> Result<()>; + fn delete_file_mxc(&self, mxc: String) -> Result<()>; - /// Returns content_disposition, content_type 
and the metadata key. - fn search_file_metadata( - &self, - mxc: String, - width: u32, - height: u32, - ) -> Result<(Option, Option, Vec)>; + /// Returns content_disposition, content_type and the metadata key. + fn search_file_metadata( + &self, mxc: String, width: u32, height: u32, + ) -> Result<(Option, Option, Vec)>; - fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>>; + fn search_mxc_metadata_prefix(&self, mxc: String) -> Result>>; - fn get_all_media_keys(&self) -> Result>>; + fn get_all_media_keys(&self) -> Result>>; - fn remove_url_preview(&self, url: &str) -> Result<()>; + fn remove_url_preview(&self, url: &str) -> Result<()>; - fn set_url_preview( - &self, - url: &str, - data: &super::UrlPreviewData, - timestamp: std::time::Duration, - ) -> Result<()>; + fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>; - fn get_url_preview(&self, url: &str) -> Option; + fn get_url_preview(&self, url: &str) -> Option; } diff --git a/src/service/media/mod.rs b/src/service/media/mod.rs index 5b5a3344..8b3b002a 100644 --- a/src/service/media/mod.rs +++ b/src/service/media/mod.rs @@ -1,579 +1,480 @@ mod data; use std::{ - collections::HashMap, - io::Cursor, - sync::{Arc, RwLock}, - time::SystemTime, + collections::HashMap, + io::Cursor, + sync::{Arc, RwLock}, + time::SystemTime, }; pub(crate) use data::Data; +use image::imageops::FilterType; use ruma::OwnedMxcUri; use serde::Serialize; +use tokio::{ + fs::{self, File}, + io::{AsyncReadExt, AsyncWriteExt, BufReader}, + sync::Mutex, +}; use tracing::{debug, error}; use crate::{services, utils, Error, Result}; -use image::imageops::FilterType; - -use tokio::{ - fs::{self, File}, - io::{AsyncReadExt, AsyncWriteExt, BufReader}, - sync::Mutex, -}; #[derive(Debug)] pub struct FileMeta { - pub content_disposition: Option, - pub content_type: Option, - pub file: Vec, + pub content_disposition: Option, + pub content_type: Option, + pub file: Vec, } 
#[derive(Serialize, Default)] pub struct UrlPreviewData { - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "og:title") - )] - pub title: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "og:description") - )] - pub description: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "og:image") - )] - pub image: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "matrix:image:size") - )] - pub image_size: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "og:image:width") - )] - pub image_width: Option, - #[serde( - skip_serializing_if = "Option::is_none", - rename(serialize = "og:image:height") - )] - pub image_height: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))] + pub title: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))] + pub description: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))] + pub image: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))] + pub image_size: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))] + pub image_width: Option, + #[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))] + pub image_height: Option, } pub struct Service { - pub db: &'static dyn Data, - pub url_preview_mutex: RwLock>>>, + pub db: &'static dyn Data, + pub url_preview_mutex: RwLock>>>, } impl Service { - /// Uploads a file. - pub async fn create( - &self, - mxc: String, - content_disposition: Option<&str>, - content_type: Option<&str>, - file: &[u8], - ) -> Result<()> { - // Width, Height = 0 if it's not a thumbnail - let key = self - .db - .create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; + /// Uploads a file. 
+ pub async fn create( + &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8], + ) -> Result<()> { + // Width, Height = 0 if it's not a thumbnail + let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?; - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - let mut f = File::create(path).await?; - f.write_all(file).await?; - Ok(()) - } + let mut f = File::create(path).await?; + f.write_all(file).await?; + Ok(()) + } - /// Deletes a file in the database and from the media directory via an MXC - pub async fn delete(&self, mxc: String) -> Result<()> { - if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) { - for key in keys { - let file_path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; - debug!("Got local file path: {:?}", file_path); + /// Deletes a file in the database and from the media directory via an MXC + pub async fn delete(&self, mxc: String) -> Result<()> { + if let Ok(keys) = self.db.search_mxc_metadata_prefix(mxc.clone()) { + for key in keys { + let file_path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; + debug!("Got local file path: {:?}", file_path); - debug!( - "Deleting local file {:?} from filesystem, original MXC: {}", - file_path, mxc - ); - tokio::fs::remove_file(file_path).await?; + debug!("Deleting local file {:?} from filesystem, original MXC: {}", file_path, mxc); + tokio::fs::remove_file(file_path).await?; - debug!("Deleting MXC 
{mxc} from database"); - self.db.delete_file_mxc(mxc.clone())?; - } + debug!("Deleting MXC {mxc} from database"); + self.db.delete_file_mxc(mxc.clone())?; + } - Ok(()) - } else { - error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)"); - Err(Error::bad_database("Failed to find any media keys for the provided MXC in our database (MXC does not exist)")) - } - } + Ok(()) + } else { + error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)"); + Err(Error::bad_database( + "Failed to find any media keys for the provided MXC in our database (MXC does not exist)", + )) + } + } - /// Uploads or replaces a file thumbnail. - pub async fn upload_thumbnail( - &self, - mxc: String, - content_disposition: Option<&str>, - content_type: Option<&str>, - width: u32, - height: u32, - file: &[u8], - ) -> Result<()> { - let key = - self.db - .create_file_metadata(mxc, width, height, content_disposition, content_type)?; + /// Uploads or replaces a file thumbnail. + pub async fn upload_thumbnail( + &self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32, + file: &[u8], + ) -> Result<()> { + let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?; - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - let mut f = File::create(path).await?; - f.write_all(file).await?; + let mut f = File::create(path).await?; + f.write_all(file).await?; - Ok(()) - } + Ok(()) + } - /// Downloads a file. 
- pub async fn get(&self, mxc: String) -> Result> { - if let Ok((content_disposition, content_type, key)) = - self.db.search_file_metadata(mxc, 0, 0) - { - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + /// Downloads a file. + pub async fn get(&self, mxc: String) -> Result> { + if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) { + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - let mut file = Vec::new(); - BufReader::new(File::open(path).await?) - .read_to_end(&mut file) - .await?; + let mut file = Vec::new(); + BufReader::new(File::open(path).await?).read_to_end(&mut file).await?; - Ok(Some(FileMeta { - content_disposition, - content_type, - file, - })) - } else { - Ok(None) - } - } + Ok(Some(FileMeta { + content_disposition, + content_type, + file, + })) + } else { + Ok(None) + } + } - /// Deletes all remote only media files in the given at or after time/duration. Returns a u32 - /// with the amount of media files deleted. - pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result { - if let Ok(all_keys) = self.db.get_all_media_keys() { - let user_duration: SystemTime = match cyborgtime::parse_duration(&time) { - Ok(duration) => { - debug!("Parsed duration: {:?}", duration); - debug!("System time now: {:?}", SystemTime::now()); - SystemTime::now() - duration - } - Err(e) => { - error!("Failed to parse user-specified time duration: {}", e); - return Err(Error::bad_database( - "Failed to parse user-specified time duration.", - )); - } - }; + /// Deletes all remote only media files in the given at or after + /// time/duration. Returns a u32 with the amount of media files deleted. 
+ pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result { + if let Ok(all_keys) = self.db.get_all_media_keys() { + let user_duration: SystemTime = match cyborgtime::parse_duration(&time) { + Ok(duration) => { + debug!("Parsed duration: {:?}", duration); + debug!("System time now: {:?}", SystemTime::now()); + SystemTime::now() - duration + }, + Err(e) => { + error!("Failed to parse user-specified time duration: {}", e); + return Err(Error::bad_database("Failed to parse user-specified time duration.")); + }, + }; - let mut remote_mxcs: Vec = vec![]; + let mut remote_mxcs: Vec = vec![]; - for key in all_keys { - debug!("Full MXC key from database: {:?}", key); + for key in all_keys { + debug!("Full MXC key from database: {:?}", key); - // we need to get the MXC URL from the first part of the key (the first 0xff / 255 push) - // this code does look kinda crazy but blame conduit for using magic keys - let mut parts = key.split(|&b| b == 0xff); - let mxc = parts - .next() - .map(|bytes| { - utils::string_from_bytes(bytes).map_err(|e| { - error!("Failed to parse MXC unicode bytes from our database: {}", e); - Error::bad_database( - "Failed to parse MXC unicode bytes from our database", - ) - }) - }) - .transpose()?; + // we need to get the MXC URL from the first part of the key (the first 0xff / + // 255 push) this code does look kinda crazy but blame conduit for using magic + // keys + let mut parts = key.split(|&b| b == 0xFF); + let mxc = parts + .next() + .map(|bytes| { + utils::string_from_bytes(bytes).map_err(|e| { + error!("Failed to parse MXC unicode bytes from our database: {}", e); + Error::bad_database("Failed to parse MXC unicode bytes from our database") + }) + }) + .transpose()?; - let mxc_s = match mxc { - Some(mxc) => mxc, - None => { - return Err(Error::bad_database( - "Parsed MXC URL unicode bytes from database but still is None", - )); - } - }; + let mxc_s = match mxc { + Some(mxc) => mxc, + None => { + return 
Err(Error::bad_database( + "Parsed MXC URL unicode bytes from database but still is None", + )); + }, + }; - debug!("Parsed MXC key to URL: {}", mxc_s); + debug!("Parsed MXC key to URL: {}", mxc_s); - let mxc = OwnedMxcUri::from(mxc_s); - if mxc.server_name() == Ok(services().globals.server_name()) { - debug!("Ignoring local media MXC: {}", mxc); - // ignore our own MXC URLs as this would be local media. - continue; - } + let mxc = OwnedMxcUri::from(mxc_s); + if mxc.server_name() == Ok(services().globals.server_name()) { + debug!("Ignoring local media MXC: {}", mxc); + // ignore our own MXC URLs as this would be local media. + continue; + } - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - debug!("MXC path: {:?}", path); + debug!("MXC path: {:?}", path); - let file_metadata = fs::metadata(path.clone()).await?; - debug!("File metadata: {:?}", file_metadata); + let file_metadata = fs::metadata(path.clone()).await?; + debug!("File metadata: {:?}", file_metadata); - let file_created_at = file_metadata.created()?; - debug!("File created at: {:?}", file_created_at); + let file_created_at = file_metadata.created()?; + debug!("File created at: {:?}", file_created_at); - if file_created_at >= user_duration { - debug!("File is within user duration, pushing to list of file paths and keys to delete."); - remote_mxcs.push(mxc.to_string()); - } - } + if file_created_at >= user_duration { + debug!("File is within user duration, pushing to list of file paths and keys to delete."); + remote_mxcs.push(mxc.to_string()); + } + } - debug!("Finished going through all our media in database for eligible keys to delete, checking if these are empty"); + debug!( + "Finished going through all 
our media in database for eligible keys to delete, checking if these are \ + empty" + ); - if remote_mxcs.is_empty() { - return Err(Error::bad_database( - "Did not found any eligible MXCs to delete.", - )); - } + if remote_mxcs.is_empty() { + return Err(Error::bad_database("Did not found any eligible MXCs to delete.")); + } - debug!("Deleting media now in the past \"{:?}\".", user_duration); + debug!("Deleting media now in the past \"{:?}\".", user_duration); - let mut deletion_count = 0; + let mut deletion_count = 0; - for mxc in remote_mxcs { - debug!("Deleting MXC {mxc} from database and filesystem"); - self.delete(mxc).await?; - deletion_count += 1; - } + for mxc in remote_mxcs { + debug!("Deleting MXC {mxc} from database and filesystem"); + self.delete(mxc).await?; + deletion_count += 1; + } - Ok(deletion_count) - } else { - Err(Error::bad_database( - "Failed to get all our media keys (filesystem or database issue?).", - )) - } - } + Ok(deletion_count) + } else { + Err(Error::bad_database( + "Failed to get all our media keys (filesystem or database issue?).", + )) + } + } - /// Returns width, height of the thumbnail and whether it should be cropped. Returns None when - /// the server should send the original file. - pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { - match (width, height) { - (0..=32, 0..=32) => Some((32, 32, true)), - (0..=96, 0..=96) => Some((96, 96, true)), - (0..=320, 0..=240) => Some((320, 240, false)), - (0..=640, 0..=480) => Some((640, 480, false)), - (0..=800, 0..=600) => Some((800, 600, false)), - _ => None, - } - } + /// Returns width, height of the thumbnail and whether it should be cropped. + /// Returns None when the server should send the original file. 
+ pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> { + match (width, height) { + (0..=32, 0..=32) => Some((32, 32, true)), + (0..=96, 0..=96) => Some((96, 96, true)), + (0..=320, 0..=240) => Some((320, 240, false)), + (0..=640, 0..=480) => Some((640, 480, false)), + (0..=800, 0..=600) => Some((800, 600, false)), + _ => None, + } + } - /// Downloads a file's thumbnail. - /// - /// Here's an example on how it works: - /// - /// - Client requests an image with width=567, height=567 - /// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails - /// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96) - /// - Server creates the thumbnail and sends it to the user - /// - /// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards. - pub async fn get_thumbnail( - &self, - mxc: String, - width: u32, - height: u32, - ) -> Result> { - let (width, height, crop) = self - .thumbnail_properties(width, height) - .unwrap_or((0, 0, false)); // 0, 0 because that's the original file + /// Downloads a file's thumbnail. + /// + /// Here's an example on how it works: + /// + /// - Client requests an image with width=567, height=567 + /// - Server rounds that up to (800, 600), so it doesn't have to save too + /// many thumbnails + /// - Server rounds that up again to (958, 600) to fix the aspect ratio + /// (only for width,height>96) + /// - Server creates the thumbnail and sends it to the user + /// + /// For width,height <= 96 the server uses another thumbnailing algorithm + /// which crops the image afterwards. 
+ pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result> { + let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file - if let Ok((content_disposition, content_type, key)) = - self.db.search_file_metadata(mxc.clone(), width, height) - { - // Using saved thumbnail - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) { + // Using saved thumbnail + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; + let mut file = Vec::new(); + File::open(path).await?.read_to_end(&mut file).await?; - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })) - } else if let Ok((content_disposition, content_type, key)) = - self.db.search_file_metadata(mxc.clone(), 0, 0) - { - // Generate a thumbnail - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&key) - }; + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.clone(), + })) + } else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) { + // Generate a thumbnail + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&key) + }; - let mut file = Vec::new(); - File::open(path).await?.read_to_end(&mut file).await?; + let mut file = Vec::new(); + 
File::open(path).await?.read_to_end(&mut file).await?; - if let Ok(image) = image::load_from_memory(&file) { - let original_width = image.width(); - let original_height = image.height(); - if width > original_width || height > original_height { - return Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })); - } + if let Ok(image) = image::load_from_memory(&file) { + let original_width = image.width(); + let original_height = image.height(); + if width > original_width || height > original_height { + return Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.clone(), + })); + } - let thumbnail = if crop { - image.resize_to_fill(width, height, FilterType::CatmullRom) - } else { - let (exact_width, exact_height) = { - // Copied from image::dynimage::resize_dimensions - let ratio = u64::from(original_width) * u64::from(height); - let nratio = u64::from(width) * u64::from(original_height); + let thumbnail = if crop { + image.resize_to_fill(width, height, FilterType::CatmullRom) + } else { + let (exact_width, exact_height) = { + // Copied from image::dynimage::resize_dimensions + let ratio = u64::from(original_width) * u64::from(height); + let nratio = u64::from(width) * u64::from(original_height); - let use_width = nratio <= ratio; - let intermediate = if use_width { - u64::from(original_height) * u64::from(width) - / u64::from(original_width) - } else { - u64::from(original_width) * u64::from(height) - / u64::from(original_height) - }; - if use_width { - if intermediate <= u64::from(::std::u32::MAX) { - (width, intermediate as u32) - } else { - ( - (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ::std::u32::MAX, - ) - } - } else if intermediate <= u64::from(::std::u32::MAX) { - (intermediate as u32, height) - } else { - ( - ::std::u32::MAX, - (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) - as u32, - ) - } - }; + let use_width = nratio <= ratio; + let intermediate = if 
use_width { + u64::from(original_height) * u64::from(width) / u64::from(original_width) + } else { + u64::from(original_width) * u64::from(height) / u64::from(original_height) + }; + if use_width { + if intermediate <= u64::from(::std::u32::MAX) { + (width, intermediate as u32) + } else { + ( + (u64::from(width) * u64::from(::std::u32::MAX) / intermediate) as u32, + ::std::u32::MAX, + ) + } + } else if intermediate <= u64::from(::std::u32::MAX) { + (intermediate as u32, height) + } else { + ( + ::std::u32::MAX, + (u64::from(height) * u64::from(::std::u32::MAX) / intermediate) as u32, + ) + } + }; - image.thumbnail_exact(exact_width, exact_height) - }; + image.thumbnail_exact(exact_width, exact_height) + }; - let mut thumbnail_bytes = Vec::new(); - thumbnail.write_to( - &mut Cursor::new(&mut thumbnail_bytes), - image::ImageOutputFormat::Png, - )?; + let mut thumbnail_bytes = Vec::new(); + thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageOutputFormat::Png)?; - // Save thumbnail in database so we don't have to generate it again next time - let thumbnail_key = self.db.create_file_metadata( - mxc, - width, - height, - content_disposition.as_deref(), - content_type.as_deref(), - )?; + // Save thumbnail in database so we don't have to generate it again next time + let thumbnail_key = self.db.create_file_metadata( + mxc, + width, + height, + content_disposition.as_deref(), + content_type.as_deref(), + )?; - let path = if cfg!(feature = "sha256_media") { - services().globals.get_media_file_new(&thumbnail_key) - } else { - #[allow(deprecated)] - services().globals.get_media_file(&thumbnail_key) - }; + let path = if cfg!(feature = "sha256_media") { + services().globals.get_media_file_new(&thumbnail_key) + } else { + #[allow(deprecated)] + services().globals.get_media_file(&thumbnail_key) + }; - let mut f = File::create(path).await?; - f.write_all(&thumbnail_bytes).await?; + let mut f = File::create(path).await?; + f.write_all(&thumbnail_bytes).await?; - 
Ok(Some(FileMeta { - content_disposition, - content_type, - file: thumbnail_bytes.clone(), - })) - } else { - // Couldn't parse file to generate thumbnail, send original - Ok(Some(FileMeta { - content_disposition, - content_type, - file: file.clone(), - })) - } - } else { - Ok(None) - } - } + Ok(Some(FileMeta { + content_disposition, + content_type, + file: thumbnail_bytes.clone(), + })) + } else { + // Couldn't parse file to generate thumbnail, send original + Ok(Some(FileMeta { + content_disposition, + content_type, + file: file.clone(), + })) + } + } else { + Ok(None) + } + } - pub async fn get_url_preview(&self, url: &str) -> Option { - self.db.get_url_preview(url) - } + pub async fn get_url_preview(&self, url: &str) -> Option { self.db.get_url_preview(url) } - pub async fn remove_url_preview(&self, url: &str) -> Result<()> { - // TODO: also remove the downloaded image - self.db.remove_url_preview(url) - } + pub async fn remove_url_preview(&self, url: &str) -> Result<()> { + // TODO: also remove the downloaded image + self.db.remove_url_preview(url) + } - pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { - let now = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .expect("valid system time"); - self.db.set_url_preview(url, data, now) - } + pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> { + let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time"); + self.db.set_url_preview(url, data, now) + } } #[cfg(test)] mod tests { - use std::path::PathBuf; + use std::path::PathBuf; - use sha2::Digest; + use base64::{engine::general_purpose, Engine as _}; + use sha2::Digest; - use base64::{engine::general_purpose, Engine as _}; + use super::*; - use super::*; + struct MockedKVDatabase; - struct MockedKVDatabase; + impl Data for MockedKVDatabase { + fn create_file_metadata( + &self, mxc: String, width: u32, height: u32, content_disposition: 
Option<&str>, content_type: Option<&str>, + ) -> Result> { + // copied from src/database/key_value/media.rs + let mut key = mxc.as_bytes().to_vec(); + key.push(0xFF); + key.extend_from_slice(&width.to_be_bytes()); + key.extend_from_slice(&height.to_be_bytes()); + key.push(0xFF); + key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default()); + key.push(0xFF); + key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default()); - impl Data for MockedKVDatabase { - fn create_file_metadata( - &self, - mxc: String, - width: u32, - height: u32, - content_disposition: Option<&str>, - content_type: Option<&str>, - ) -> Result> { - // copied from src/database/key_value/media.rs - let mut key = mxc.as_bytes().to_vec(); - key.push(0xff); - key.extend_from_slice(&width.to_be_bytes()); - key.extend_from_slice(&height.to_be_bytes()); - key.push(0xff); - key.extend_from_slice( - content_disposition - .as_ref() - .map(|f| f.as_bytes()) - .unwrap_or_default(), - ); - key.push(0xff); - key.extend_from_slice( - content_type - .as_ref() - .map(|c| c.as_bytes()) - .unwrap_or_default(), - ); + Ok(key) + } - Ok(key) - } + fn delete_file_mxc(&self, _mxc: String) -> Result<()> { todo!() } - fn delete_file_mxc(&self, _mxc: String) -> Result<()> { - todo!() - } + fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result>> { todo!() } - fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result>> { - todo!() - } + fn get_all_media_keys(&self) -> Result>> { todo!() } - fn get_all_media_keys(&self) -> Result>> { - todo!() - } + fn search_file_metadata( + &self, _mxc: String, _width: u32, _height: u32, + ) -> Result<(Option, Option, Vec)> { + todo!() + } - fn search_file_metadata( - &self, - _mxc: String, - _width: u32, - _height: u32, - ) -> Result<(Option, Option, Vec)> { - todo!() - } + fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() } - fn remove_url_preview(&self, _url: &str) -> Result<()> { - todo!() - } + fn 
set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> { + todo!() + } - fn set_url_preview( - &self, - _url: &str, - _data: &UrlPreviewData, - _timestamp: std::time::Duration, - ) -> Result<()> { - todo!() - } + fn get_url_preview(&self, _url: &str) -> Option { todo!() } + } - fn get_url_preview(&self, _url: &str) -> Option { - todo!() - } - } + #[tokio::test] + async fn long_file_names_works() { + static DB: MockedKVDatabase = MockedKVDatabase; + let media = Service { + db: &DB, + url_preview_mutex: RwLock::new(HashMap::new()), + }; - #[tokio::test] - async fn long_file_names_works() { - static DB: MockedKVDatabase = MockedKVDatabase; - let media = Service { - db: &DB, - url_preview_mutex: RwLock::new(HashMap::new()), - }; - - let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned(); - let width = 100; - let height = 100; - let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\""; - let content_type = "image/png"; - let key = media - .db - .create_file_metadata( - mxc, - width, - height, - Some(content_disposition), - Some(content_type), - ) - .unwrap(); - let mut r = PathBuf::new(); - r.push("/tmp"); - r.push("media"); - // r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD)); - // use the sha256 hash of the key as the file name instead of the key itself - // this is because the base64 encoded key can be longer than 255 characters. 
- r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key))); - // Check that the file path is not longer than 255 characters - // (255 is the maximum length of a file path on most file systems) - assert!( - r.to_str().unwrap().len() <= 255, - "File path is too long: {}", - r.to_str().unwrap().len() - ); - } + let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned(); + let width = 100; + let height = 100; + let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \ + characters like äöüß and even emoji like 🦀.png\""; + let content_type = "image/png"; + let key = + media.db.create_file_metadata(mxc, width, height, Some(content_disposition), Some(content_type)).unwrap(); + let mut r = PathBuf::new(); + r.push("/tmp"); + r.push("media"); + // r.push(base64::encode_config(key, base64::URL_SAFE_NO_PAD)); + // use the sha256 hash of the key as the file name instead of the key itself + // this is because the base64 encoded key can be longer than 255 characters. 
+ r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key))); + // Check that the file path is not longer than 255 characters + // (255 is the maximum length of a file path on most file systems) + assert!( + r.to_str().unwrap().len() <= 255, + "File path is too long: {}", + r.to_str().unwrap().len() + ); + } } diff --git a/src/service/mod.rs b/src/service/mod.rs index 1902fa8c..9bee19fd 100644 --- a/src/service/mod.rs +++ b/src/service/mod.rs @@ -1,6 +1,6 @@ use std::{ - collections::{BTreeMap, HashMap}, - sync::{Arc, Mutex, RwLock}, + collections::{BTreeMap, HashMap}, + sync::{Arc, Mutex, RwLock}, }; use lru_cache::LruCache; @@ -22,210 +22,186 @@ pub(crate) mod uiaa; pub(crate) mod users; pub struct Services<'a> { - pub appservice: appservice::Service, - pub pusher: pusher::Service, - pub rooms: rooms::Service, - pub transaction_ids: transaction_ids::Service, - pub uiaa: uiaa::Service, - pub users: users::Service, - pub account_data: account_data::Service, - pub admin: Arc, - pub globals: globals::Service<'a>, - pub key_backups: key_backups::Service, - pub media: media::Service, - pub sending: Arc, + pub appservice: appservice::Service, + pub pusher: pusher::Service, + pub rooms: rooms::Service, + pub transaction_ids: transaction_ids::Service, + pub uiaa: uiaa::Service, + pub users: users::Service, + pub account_data: account_data::Service, + pub admin: Arc, + pub globals: globals::Service<'a>, + pub key_backups: key_backups::Service, + pub media: media::Service, + pub sending: Arc, } impl Services<'_> { - pub fn build< - D: appservice::Data - + pusher::Data - + rooms::Data - + transaction_ids::Data - + uiaa::Data - + users::Data - + account_data::Data - + globals::Data - + key_backups::Data - + media::Data - + sending::Data - + 'static, - >( - db: &'static D, - config: Config, - ) -> Result { - Ok(Self { - appservice: appservice::Service { db }, - pusher: pusher::Service { db }, - rooms: rooms::Service { - alias: rooms::alias::Service { db }, - 
auth_chain: rooms::auth_chain::Service { db }, - directory: rooms::directory::Service { db }, - edus: rooms::edus::Service { - presence: rooms::edus::presence::Service { db }, - read_receipt: rooms::edus::read_receipt::Service { db }, - typing: rooms::edus::typing::Service { db }, - }, - event_handler: rooms::event_handler::Service, - lazy_loading: rooms::lazy_loading::Service { - db, - lazy_load_waiting: Mutex::new(HashMap::new()), - }, - metadata: rooms::metadata::Service { db }, - outlier: rooms::outlier::Service { db }, - pdu_metadata: rooms::pdu_metadata::Service { db }, - search: rooms::search::Service { db }, - short: rooms::short::Service { db }, - state: rooms::state::Service { db }, - state_accessor: rooms::state_accessor::Service { - db, - server_visibility_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - user_visibility_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - }, - state_cache: rooms::state_cache::Service { db }, - state_compressor: rooms::state_compressor::Service { - db, - stateinfo_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - }, - timeline: rooms::timeline::Service { - db, - lasttimelinecount_cache: Mutex::new(HashMap::new()), - }, - threads: rooms::threads::Service { db }, - spaces: rooms::spaces::Service { - roomid_spacechunk_cache: Mutex::new(LruCache::new( - (100.0 * config.conduit_cache_capacity_modifier) as usize, - )), - }, - user: rooms::user::Service { db }, - }, - transaction_ids: transaction_ids::Service { db }, - uiaa: uiaa::Service { db }, - users: users::Service { - db, - connections: Mutex::new(BTreeMap::new()), - }, - account_data: account_data::Service { db }, - admin: admin::Service::build(), - key_backups: key_backups::Service { db }, - media: media::Service { - db, - url_preview_mutex: RwLock::new(HashMap::new()), - }, - sending: sending::Service::build(db, 
&config), + pub fn build< + D: appservice::Data + + pusher::Data + + rooms::Data + + transaction_ids::Data + + uiaa::Data + + users::Data + + account_data::Data + + globals::Data + + key_backups::Data + + media::Data + + sending::Data + + 'static, + >( + db: &'static D, config: Config, + ) -> Result { + Ok(Self { + appservice: appservice::Service { + db, + }, + pusher: pusher::Service { + db, + }, + rooms: rooms::Service { + alias: rooms::alias::Service { + db, + }, + auth_chain: rooms::auth_chain::Service { + db, + }, + directory: rooms::directory::Service { + db, + }, + edus: rooms::edus::Service { + presence: rooms::edus::presence::Service { + db, + }, + read_receipt: rooms::edus::read_receipt::Service { + db, + }, + typing: rooms::edus::typing::Service { + db, + }, + }, + event_handler: rooms::event_handler::Service, + lazy_loading: rooms::lazy_loading::Service { + db, + lazy_load_waiting: Mutex::new(HashMap::new()), + }, + metadata: rooms::metadata::Service { + db, + }, + outlier: rooms::outlier::Service { + db, + }, + pdu_metadata: rooms::pdu_metadata::Service { + db, + }, + search: rooms::search::Service { + db, + }, + short: rooms::short::Service { + db, + }, + state: rooms::state::Service { + db, + }, + state_accessor: rooms::state_accessor::Service { + db, + server_visibility_cache: Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + user_visibility_cache: Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + }, + state_cache: rooms::state_cache::Service { + db, + }, + state_compressor: rooms::state_compressor::Service { + db, + stateinfo_cache: Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + }, + timeline: rooms::timeline::Service { + db, + lasttimelinecount_cache: Mutex::new(HashMap::new()), + }, + threads: rooms::threads::Service { + db, + }, + spaces: rooms::spaces::Service { + roomid_spacechunk_cache: 
Mutex::new(LruCache::new( + (100.0 * config.conduit_cache_capacity_modifier) as usize, + )), + }, + user: rooms::user::Service { + db, + }, + }, + transaction_ids: transaction_ids::Service { + db, + }, + uiaa: uiaa::Service { + db, + }, + users: users::Service { + db, + connections: Mutex::new(BTreeMap::new()), + }, + account_data: account_data::Service { + db, + }, + admin: admin::Service::build(), + key_backups: key_backups::Service { + db, + }, + media: media::Service { + db, + url_preview_mutex: RwLock::new(HashMap::new()), + }, + sending: sending::Service::build(db, &config), - globals: globals::Service::load(db, config)?, - }) - } - fn memory_usage(&self) -> String { - let lazy_load_waiting = self - .rooms - .lazy_loading - .lazy_load_waiting - .lock() - .unwrap() - .len(); - let server_visibility_cache = self - .rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .len(); - let user_visibility_cache = self - .rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .len(); - let stateinfo_cache = self - .rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .len(); - let lasttimelinecount_cache = self - .rooms - .timeline - .lasttimelinecount_cache - .lock() - .unwrap() - .len(); - let roomid_spacechunk_cache = self - .rooms - .spaces - .roomid_spacechunk_cache - .lock() - .unwrap() - .len(); + globals: globals::Service::load(db, config)?, + }) + } - format!( - "\ + fn memory_usage(&self) -> String { + let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len(); + let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len(); + let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len(); + let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len(); + let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len(); + let 
roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len(); + + format!( + "\ lazy_load_waiting: {lazy_load_waiting} server_visibility_cache: {server_visibility_cache} user_visibility_cache: {user_visibility_cache} stateinfo_cache: {stateinfo_cache} lasttimelinecount_cache: {lasttimelinecount_cache} -roomid_spacechunk_cache: {roomid_spacechunk_cache}\ - " - ) - } - fn clear_caches(&self, amount: u32) { - if amount > 0 { - self.rooms - .lazy_loading - .lazy_load_waiting - .lock() - .unwrap() - .clear(); - } - if amount > 1 { - self.rooms - .state_accessor - .server_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 2 { - self.rooms - .state_accessor - .user_visibility_cache - .lock() - .unwrap() - .clear(); - } - if amount > 3 { - self.rooms - .state_compressor - .stateinfo_cache - .lock() - .unwrap() - .clear(); - } - if amount > 4 { - self.rooms - .timeline - .lasttimelinecount_cache - .lock() - .unwrap() - .clear(); - } - if amount > 5 { - self.rooms - .spaces - .roomid_spacechunk_cache - .lock() - .unwrap() - .clear(); - } - } +roomid_spacechunk_cache: {roomid_spacechunk_cache}" + ) + } + + fn clear_caches(&self, amount: u32) { + if amount > 0 { + self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear(); + } + if amount > 1 { + self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear(); + } + if amount > 2 { + self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear(); + } + if amount > 3 { + self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear(); + } + if amount > 4 { + self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear(); + } + if amount > 5 { + self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear(); + } + } } diff --git a/src/service/pdu.rs b/src/service/pdu.rs index 0a9ea861..a7eafb06 100644 --- a/src/service/pdu.rs +++ b/src/service/pdu.rs @@ -1,410 +1,372 @@ -use crate::Error; +use std::{cmp::Ordering, collections::BTreeMap, 
sync::Arc}; + use ruma::{ - canonical_json::redact_content_in_place, - events::{ - room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, - AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, - AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType, - }, - serde::Raw, - state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, - OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId, + canonical_json::redact_content_in_place, + events::{ + room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent, + AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent, + AnyTimelineEvent, StateEvent, TimelineEventType, + }, + serde::Raw, + state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId, + OwnedUserId, RoomId, RoomVersionId, UInt, UserId, }; use serde::{Deserialize, Serialize}; use serde_json::{ - json, - value::{to_raw_value, RawValue as RawJsonValue}, + json, + value::{to_raw_value, RawValue as RawJsonValue}, }; -use std::{cmp::Ordering, collections::BTreeMap, sync::Arc}; use tracing::warn; +use crate::Error; + /// Content hashes of a PDU. #[derive(Clone, Debug, Deserialize, Serialize)] pub struct EventHash { - /// The SHA-256 hash. - pub sha256: String, + /// The SHA-256 hash. 
+ pub sha256: String, } #[derive(Clone, Deserialize, Serialize, Debug)] pub struct PduEvent { - pub event_id: Arc, - pub room_id: OwnedRoomId, - pub sender: OwnedUserId, - pub origin_server_ts: UInt, - #[serde(rename = "type")] - pub kind: TimelineEventType, - pub content: Box, - #[serde(skip_serializing_if = "Option::is_none")] - pub state_key: Option, - pub prev_events: Vec>, - pub depth: UInt, - pub auth_events: Vec>, - #[serde(skip_serializing_if = "Option::is_none")] - pub redacts: Option>, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub unsigned: Option>, - pub hashes: EventHash, - #[serde(default, skip_serializing_if = "Option::is_none")] - pub signatures: Option>, // BTreeMap, BTreeMap> + pub event_id: Arc, + pub room_id: OwnedRoomId, + pub sender: OwnedUserId, + pub origin_server_ts: UInt, + #[serde(rename = "type")] + pub kind: TimelineEventType, + pub content: Box, + #[serde(skip_serializing_if = "Option::is_none")] + pub state_key: Option, + pub prev_events: Vec>, + pub depth: UInt, + pub auth_events: Vec>, + #[serde(skip_serializing_if = "Option::is_none")] + pub redacts: Option>, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub unsigned: Option>, + pub hashes: EventHash, + #[serde(default, skip_serializing_if = "Option::is_none")] + pub signatures: Option>, // BTreeMap, BTreeMap> } impl PduEvent { - #[tracing::instrument(skip(self))] - pub fn redact( - &mut self, - room_version_id: RoomVersionId, - reason: &PduEvent, - ) -> crate::Result<()> { - self.unsigned = None; + #[tracing::instrument(skip(self))] + pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> { + self.unsigned = None; - let mut content = serde_json::from_str(self.content.get()) - .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; - redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) - .map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?; 
+ let mut content = serde_json::from_str(self.content.get()) + .map_err(|_| Error::bad_database("PDU in db has invalid content."))?; + redact_content_in_place(&mut content, &room_version_id, self.kind.to_string()) + .map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?; - self.unsigned = Some(to_raw_value(&json!({ - "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") - })).expect("to string always works")); + self.unsigned = Some( + to_raw_value(&json!({ + "redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works") + })) + .expect("to string always works"), + ); - self.content = to_raw_value(&content).expect("to string always works"); + self.content = to_raw_value(&content).expect("to string always works"); - Ok(()) - } + Ok(()) + } - pub fn remove_transaction_id(&mut self) -> crate::Result<()> { - if let Some(unsigned) = &self.unsigned { - let mut unsigned: BTreeMap> = - serde_json::from_str(unsigned.get()) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.remove("transaction_id"); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - } + pub fn remove_transaction_id(&mut self) -> crate::Result<()> { + if let Some(unsigned) = &self.unsigned { + let mut unsigned: BTreeMap> = serde_json::from_str(unsigned.get()) + .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + unsigned.remove("transaction_id"); + self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + } - Ok(()) - } + Ok(()) + } - pub fn add_age(&mut self) -> crate::Result<()> { - let mut unsigned: BTreeMap> = self - .unsigned - .as_ref() - .map_or_else(|| Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) - .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; + pub fn add_age(&mut self) -> crate::Result<()> { + let mut unsigned: BTreeMap> = self + .unsigned + .as_ref() + .map_or_else(|| 
Ok(BTreeMap::new()), |u| serde_json::from_str(u.get())) + .map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?; - unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); - self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); + unsigned.insert("age".to_owned(), to_raw_value(&1).unwrap()); + self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid")); - Ok(()) - } + Ok(()) + } - #[tracing::instrument(skip(self))] - pub fn to_sync_room_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - }); + #[tracing::instrument(skip(self))] + pub fn to_sync_room_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &self.redacts { - json["redacts"] = json!(redacts); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &self.redacts { + json["redacts"] = json!(redacts); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - /// This only works for events that are also AnyRoomEvents. 
- #[tracing::instrument(skip(self))] - pub fn to_any_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); + /// This only works for events that are also AnyRoomEvents. + #[tracing::instrument(skip(self))] + pub fn to_any_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &self.redacts { - json["redacts"] = json!(redacts); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &self.redacts { + json["redacts"] = json!(redacts); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_room_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); + #[tracing::instrument(skip(self))] + pub fn to_room_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = 
json!(state_key); - } - if let Some(redacts) = &self.redacts { - json["redacts"] = json!(redacts); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &self.redacts { + json["redacts"] = json!(redacts); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_message_like_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "room_id": self.room_id, - }); + #[tracing::instrument(skip(self))] + pub fn to_message_like_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } - if let Some(state_key) = &self.state_key { - json["state_key"] = json!(state_key); - } - if let Some(redacts) = &self.redacts { - json["redacts"] = json!(redacts); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } + if let Some(state_key) = &self.state_key { + json["state_key"] = json!(state_key); + } + if let Some(redacts) = &self.redacts { + json["redacts"] = json!(redacts); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_state_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - 
"room_id": self.room_id, - "state_key": self.state_key, - }); + #[tracing::instrument(skip(self))] + pub fn to_state_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "room_id": self.room_id, + "state_key": self.state_key, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_sync_state_event(&self) -> Raw { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "state_key": self.state_key, - }); + #[tracing::instrument(skip(self))] + pub fn to_sync_state_event(&self) -> Raw { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "state_key": self.state_key, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_stripped_state_event(&self) -> Raw { - let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - }); + #[tracing::instrument(skip(self))] + pub fn to_stripped_state_event(&self) -> Raw { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + 
"state_key": self.state_key, + }); - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_stripped_spacechild_state_event(&self) -> Raw { - let json = json!({ - "content": self.content, - "type": self.kind, - "sender": self.sender, - "state_key": self.state_key, - "origin_server_ts": self.origin_server_ts, - }); + #[tracing::instrument(skip(self))] + pub fn to_stripped_spacechild_state_event(&self) -> Raw { + let json = json!({ + "content": self.content, + "type": self.kind, + "sender": self.sender, + "state_key": self.state_key, + "origin_server_ts": self.origin_server_ts, + }); - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - #[tracing::instrument(skip(self))] - pub fn to_member_event(&self) -> Raw> { - let mut json = json!({ - "content": self.content, - "type": self.kind, - "event_id": self.event_id, - "sender": self.sender, - "origin_server_ts": self.origin_server_ts, - "redacts": self.redacts, - "room_id": self.room_id, - "state_key": self.state_key, - }); + #[tracing::instrument(skip(self))] + pub fn to_member_event(&self) -> Raw> { + let mut json = json!({ + "content": self.content, + "type": self.kind, + "event_id": self.event_id, + "sender": self.sender, + "origin_server_ts": self.origin_server_ts, + "redacts": self.redacts, + "room_id": self.room_id, + "state_key": self.state_key, + }); - if let Some(unsigned) = &self.unsigned { - json["unsigned"] = json!(unsigned); - } + if let Some(unsigned) = &self.unsigned { + json["unsigned"] = json!(unsigned); + } - serde_json::from_value(json).expect("Raw::from_value always works") - } + serde_json::from_value(json).expect("Raw::from_value always works") + } - /// This does not return a full `Pdu` it is only to satisfy ruma's types. 
- #[tracing::instrument] - pub fn convert_to_outgoing_federation_event( - mut pdu_json: CanonicalJsonObject, - ) -> Box { - if let Some(unsigned) = pdu_json - .get_mut("unsigned") - .and_then(|val| val.as_object_mut()) - { - unsigned.remove("transaction_id"); - } + /// This does not return a full `Pdu` it is only to satisfy ruma's types. + #[tracing::instrument] + pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box { + if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) { + unsigned.remove("transaction_id"); + } - pdu_json.remove("event_id"); + pdu_json.remove("event_id"); - // TODO: another option would be to convert it to a canonical string to validate size - // and return a Result> - // serde_json::from_str::>( - // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is valid serde_json::Value"), - // ) - // .expect("Raw::from_value always works") + // TODO: another option would be to convert it to a canonical string to validate + // size and return a Result> + // serde_json::from_str::>( + // ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is + // valid serde_json::Value"), ) + // .expect("Raw::from_value always works") - to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") - } + to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value") + } - pub fn from_id_val( - event_id: &EventId, - mut json: CanonicalJsonObject, - ) -> Result { - json.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); + pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result { + json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); - serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) - } + serde_json::from_value(serde_json::to_value(json).expect("valid JSON")) + } } impl state_res::Event for PduEvent { - type Id = 
Arc; + type Id = Arc; - fn event_id(&self) -> &Self::Id { - &self.event_id - } + fn event_id(&self) -> &Self::Id { &self.event_id } - fn room_id(&self) -> &RoomId { - &self.room_id - } + fn room_id(&self) -> &RoomId { &self.room_id } - fn sender(&self) -> &UserId { - &self.sender - } + fn sender(&self) -> &UserId { &self.sender } - fn event_type(&self) -> &TimelineEventType { - &self.kind - } + fn event_type(&self) -> &TimelineEventType { &self.kind } - fn content(&self) -> &RawJsonValue { - &self.content - } + fn content(&self) -> &RawJsonValue { &self.content } - fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { - MilliSecondsSinceUnixEpoch(self.origin_server_ts) - } + fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) } - fn state_key(&self) -> Option<&str> { - self.state_key.as_deref() - } + fn state_key(&self) -> Option<&str> { self.state_key.as_deref() } - fn prev_events(&self) -> Box + '_> { - Box::new(self.prev_events.iter()) - } + fn prev_events(&self) -> Box + '_> { Box::new(self.prev_events.iter()) } - fn auth_events(&self) -> Box + '_> { - Box::new(self.auth_events.iter()) - } + fn auth_events(&self) -> Box + '_> { Box::new(self.auth_events.iter()) } - fn redacts(&self) -> Option<&Self::Id> { - self.redacts.as_ref() - } + fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() } } // These impl's allow us to dedup state snapshots when resolving state // for incoming events (federation/send/{txn}). 
impl Eq for PduEvent {} impl PartialEq for PduEvent { - fn eq(&self, other: &Self) -> bool { - self.event_id == other.event_id - } + fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id } } impl PartialOrd for PduEvent { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for PduEvent { - fn cmp(&self, other: &Self) -> Ordering { - self.event_id.cmp(&other.event_id) - } + fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) } } /// Generates a correct eventId for the incoming pdu. /// -/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. +/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap`. pub(crate) fn gen_event_id_canonical_json( - pdu: &RawJsonValue, - room_version_id: &RoomVersionId, + pdu: &RawJsonValue, room_version_id: &RoomVersionId, ) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - warn!("Error parsing incoming event {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + warn!("Error parsing incoming event {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; - let event_id = format!( - "${}", - // Anything higher than version3 behaves the same - ruma::signatures::reference_hash(&value, room_version_id) - .expect("ruma can calculate reference hashes") - ) - .try_into() - .expect("ruma's reference hashes are valid event ids"); + let event_id = format!( + "${}", + // Anything higher than version3 behaves the same + ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes") + ) + .try_into() + .expect("ruma's reference hashes are valid event ids"); - Ok((event_id, value)) + Ok((event_id, 
value)) } /// Build the start of a PDU in order to add it to the Database. #[derive(Debug, Deserialize)] pub struct PduBuilder { - #[serde(rename = "type")] - pub event_type: TimelineEventType, - pub content: Box, - pub unsigned: Option>, - pub state_key: Option, - pub redacts: Option>, + #[serde(rename = "type")] + pub event_type: TimelineEventType, + pub content: Box, + pub unsigned: Option>, + pub state_key: Option, + pub redacts: Option>, } diff --git a/src/service/pusher/data.rs b/src/service/pusher/data.rs index 2062f567..b58cd3fc 100644 --- a/src/service/pusher/data.rs +++ b/src/service/pusher/data.rs @@ -1,16 +1,16 @@ -use crate::Result; use ruma::{ - api::client::push::{set_pusher, Pusher}, - UserId, + api::client::push::{set_pusher, Pusher}, + UserId, }; +use crate::Result; + pub trait Data: Send + Sync { - fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; + fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>; - fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; + fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result>; - fn get_pushers(&self, sender: &UserId) -> Result>; + fn get_pushers(&self, sender: &UserId) -> Result>; - fn get_pushkeys<'a>(&'a self, sender: &UserId) - -> Box> + 'a>; + fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box> + 'a>; } diff --git a/src/service/pusher/mod.rs b/src/service/pusher/mod.rs index 2131e918..fc1078a8 100644 --- a/src/service/pusher/mod.rs +++ b/src/service/pusher/mod.rs @@ -1,292 +1,236 @@ mod data; -pub use data::Data; -use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx}; - -use crate::{services, Error, PduEvent, Result}; -use bytes::BytesMut; -use ruma::{ - api::{ - client::push::{set_pusher, Pusher, PusherKind}, - push_gateway::send_event_notification::{ - self, - v1::{Device, Notification, NotificationCounts, NotificationPriority}, - }, - IncomingResponse, MatrixVersion, 
OutgoingRequest, SendAccessToken, - }, - events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType}, - push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, - serde::Raw, - uint, RoomId, UInt, UserId, -}; - use std::{fmt::Debug, mem}; + +use bytes::BytesMut; +pub use data::Data; +use ruma::{ + api::{ + client::push::{set_pusher, Pusher, PusherKind}, + push_gateway::send_event_notification::{ + self, + v1::{Device, Notification, NotificationCounts, NotificationPriority}, + }, + IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken, + }, + events::{ + room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType, + }, + push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak}, + serde::Raw, + uint, RoomId, UInt, UserId, +}; use tracing::{info, warn}; +use crate::{services, Error, PduEvent, Result}; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { - self.db.set_pusher(sender, pusher) - } + pub fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()> { + self.db.set_pusher(sender, pusher) + } - pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { - self.db.get_pusher(sender, pushkey) - } + pub fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result> { + self.db.get_pusher(sender, pushkey) + } - pub fn get_pushers(&self, sender: &UserId) -> Result> { - self.db.get_pushers(sender) - } + pub fn get_pushers(&self, sender: &UserId) -> Result> { self.db.get_pushers(sender) } - pub fn get_pushkeys(&self, sender: &UserId) -> Box>> { - self.db.get_pushkeys(sender) - } + pub fn get_pushkeys(&self, sender: &UserId) -> Box>> { + self.db.get_pushkeys(sender) + } - #[tracing::instrument(skip(self, destination, request))] - pub async fn 
send_request( - &self, - destination: &str, - request: T, - ) -> Result - where - T: OutgoingRequest + Debug, - { - let destination = destination.replace(services().globals.notification_push_path(), ""); + #[tracing::instrument(skip(self, destination, request))] + pub async fn send_request(&self, destination: &str, request: T) -> Result + where + T: OutgoingRequest + Debug, + { + let destination = destination.replace(services().globals.notification_push_path(), ""); - let http_request = request - .try_into_http_request::( - &destination, - SendAccessToken::IfRequired(""), - &[MatrixVersion::V1_0], - ) - .map_err(|e| { - warn!("Failed to find destination {}: {}", destination, e); - Error::BadServerResponse("Invalid destination") - })? - .map(bytes::BytesMut::freeze); + let http_request = request + .try_into_http_request::(&destination, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_0]) + .map_err(|e| { + warn!("Failed to find destination {}: {}", destination, e); + Error::BadServerResponse("Invalid destination") + })? + .map(bytes::BytesMut::freeze); - let reqwest_request = reqwest::Request::try_from(http_request)?; + let reqwest_request = reqwest::Request::try_from(http_request)?; - // TODO: we could keep this very short and let expo backoff do it's thing... - //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); + // TODO: we could keep this very short and let expo backoff do it's thing... 
+ //*reqwest_request.timeout_mut() = Some(Duration::from_secs(5)); - let url = reqwest_request.url().clone(); - let response = services() - .globals - .default_client() - .execute(reqwest_request) - .await; + let url = reqwest_request.url().clone(); + let response = services().globals.default_client().execute(reqwest_request).await; - match response { - Ok(mut response) => { - // reqwest::Response -> http::Response conversion - let status = response.status(); - let mut http_response_builder = http::Response::builder() - .status(status) - .version(response.version()); - mem::swap( - response.headers_mut(), - http_response_builder - .headers_mut() - .expect("http::response::Builder is usable"), - ); + match response { + Ok(mut response) => { + // reqwest::Response -> http::Response conversion + let status = response.status(); + let mut http_response_builder = http::Response::builder().status(status).version(response.version()); + mem::swap( + response.headers_mut(), + http_response_builder.headers_mut().expect("http::response::Builder is usable"), + ); - let body = response.bytes().await.unwrap_or_else(|e| { - warn!("server error {}", e); - Vec::new().into() - }); // TODO: handle timeout + let body = response.bytes().await.unwrap_or_else(|e| { + warn!("server error {}", e); + Vec::new().into() + }); // TODO: handle timeout - if !status.is_success() { - info!( - "Push gateway returned bad response {} {}\n{}\n{:?}", - destination, - status, - url, - crate::utils::string_from_bytes(&body) - ); - } + if !status.is_success() { + info!( + "Push gateway returned bad response {} {}\n{}\n{:?}", + destination, + status, + url, + crate::utils::string_from_bytes(&body) + ); + } - let response = T::IncomingResponse::try_from_http_response( - http_response_builder - .body(body) - .expect("reqwest body is valid http body"), - ); - response.map_err(|_| { - info!( - "Push gateway returned invalid response bytes {}\n{}", - destination, url - ); - Error::BadServerResponse("Push gateway 
returned bad response.") - }) - } - Err(e) => { - warn!("Could not send request to pusher {}: {}", destination, e); - Err(e.into()) - } - } - } + let response = T::IncomingResponse::try_from_http_response( + http_response_builder.body(body).expect("reqwest body is valid http body"), + ); + response.map_err(|_| { + info!("Push gateway returned invalid response bytes {}\n{}", destination, url); + Error::BadServerResponse("Push gateway returned bad response.") + }) + }, + Err(e) => { + warn!("Could not send request to pusher {}: {}", destination, e); + Err(e.into()) + }, + } + } - #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] - pub async fn send_push_notice( - &self, - user: &UserId, - unread: UInt, - pusher: &Pusher, - ruleset: Ruleset, - pdu: &PduEvent, - ) -> Result<()> { - let mut notify = None; - let mut tweaks = Vec::new(); + #[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))] + pub async fn send_push_notice( + &self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent, + ) -> Result<()> { + let mut notify = None; + let mut tweaks = Vec::new(); - let power_levels: RoomPowerLevelsEventContent = services() - .rooms - .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); - for action in self.get_actions( - user, - &ruleset, - &power_levels, - &pdu.to_sync_room_event(), - &pdu.room_id, - )? 
{ - let n = match action { - Action::Notify => true, - Action::SetTweak(tweak) => { - tweaks.push(tweak.clone()); - continue; - } - _ => false, - }; + for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? { + let n = match action { + Action::Notify => true, + Action::SetTweak(tweak) => { + tweaks.push(tweak.clone()); + continue; + }, + _ => false, + }; - if notify.is_some() { - return Err(Error::bad_database( - r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, - )); - } + if notify.is_some() { + return Err(Error::bad_database( + r#"Malformed pushrule contains more than one of these actions: ["dont_notify", "notify", "coalesce"]"#, + )); + } - notify = Some(n); - } + notify = Some(n); + } - if notify == Some(true) { - self.send_notice(unread, pusher, tweaks, pdu).await?; - } - // Else the event triggered no actions + if notify == Some(true) { + self.send_notice(unread, pusher, tweaks, pdu).await?; + } + // Else the event triggered no actions - Ok(()) - } + Ok(()) + } - #[tracing::instrument(skip(self, user, ruleset, pdu))] - pub fn get_actions<'a>( - &self, - user: &UserId, - ruleset: &'a Ruleset, - power_levels: &RoomPowerLevelsEventContent, - pdu: &Raw, - room_id: &RoomId, - ) -> Result<&'a [Action]> { - let power_levels = PushConditionPowerLevelsCtx { - users: power_levels.users.clone(), - users_default: power_levels.users_default, - notifications: power_levels.notifications.clone(), - }; + #[tracing::instrument(skip(self, user, ruleset, pdu))] + pub fn get_actions<'a>( + &self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent, + pdu: &Raw, room_id: &RoomId, + ) -> Result<&'a [Action]> { + let power_levels = PushConditionPowerLevelsCtx { + users: power_levels.users.clone(), + users_default: power_levels.users_default, + notifications: power_levels.notifications.clone(), + }; - let ctx = PushConditionRoomCtx { - room_id: 
room_id.to_owned(), - member_count: UInt::from( - services() - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or(1) as u32, - ), - user_id: user.to_owned(), - user_display_name: services() - .users - .displayname(user)? - .unwrap_or_else(|| user.localpart().to_owned()), - power_levels: Some(power_levels), - }; + let ctx = PushConditionRoomCtx { + room_id: room_id.to_owned(), + member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32), + user_id: user.to_owned(), + user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()), + power_levels: Some(power_levels), + }; - Ok(ruleset.get_actions(pdu, &ctx)) - } + Ok(ruleset.get_actions(pdu, &ctx)) + } - #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] - async fn send_notice( - &self, - unread: UInt, - pusher: &Pusher, - tweaks: Vec, - event: &PduEvent, - ) -> Result<()> { - // TODO: email - match &pusher.kind { - PusherKind::Http(http) => { - // TODO: - // Two problems with this - // 1. if "event_id_only" is the only format kind it seems we should never add more info - // 2. can pusher/devices have conflicting formats - let event_id_only = http.format == Some(PushFormat::EventIdOnly); + #[tracing::instrument(skip(self, unread, pusher, tweaks, event))] + async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec, event: &PduEvent) -> Result<()> { + // TODO: email + match &pusher.kind { + PusherKind::Http(http) => { + // TODO: + // Two problems with this + // 1. if "event_id_only" is the only format kind it seems we should never add + // more info + // 2. 
can pusher/devices have conflicting formats + let event_id_only = http.format == Some(PushFormat::EventIdOnly); - let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); - device.data.default_payload = http.default_payload.clone(); - device.data.format = http.format.clone(); + let mut device = Device::new(pusher.ids.app_id.clone(), pusher.ids.pushkey.clone()); + device.data.default_payload = http.default_payload.clone(); + device.data.format = http.format.clone(); - // Tweaks are only added if the format is NOT event_id_only - if !event_id_only { - device.tweaks = tweaks.clone(); - } + // Tweaks are only added if the format is NOT event_id_only + if !event_id_only { + device.tweaks = tweaks.clone(); + } - let d = vec![device]; - let mut notifi = Notification::new(d); + let d = vec![device]; + let mut notifi = Notification::new(d); - notifi.prio = NotificationPriority::Low; - notifi.event_id = Some((*event.event_id).to_owned()); - notifi.room_id = Some((*event.room_id).to_owned()); - // TODO: missed calls - notifi.counts = NotificationCounts::new(unread, uint!(0)); + notifi.prio = NotificationPriority::Low; + notifi.event_id = Some((*event.event_id).to_owned()); + notifi.room_id = Some((*event.room_id).to_owned()); + // TODO: missed calls + notifi.counts = NotificationCounts::new(unread, uint!(0)); - if event.kind == TimelineEventType::RoomEncrypted - || tweaks - .iter() - .any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) - { - notifi.prio = NotificationPriority::High; - } + if event.kind == TimelineEventType::RoomEncrypted + || tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) + { + notifi.prio = NotificationPriority::High; + } - if event_id_only { - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) - .await?; - } else { - notifi.sender = Some(event.sender.clone()); - notifi.event_type = Some(event.kind.clone()); - notifi.content = 
serde_json::value::to_raw_value(&event.content).ok(); + if event_id_only { + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; + } else { + notifi.sender = Some(event.sender.clone()); + notifi.event_type = Some(event.kind.clone()); + notifi.content = serde_json::value::to_raw_value(&event.content).ok(); - if event.kind == TimelineEventType::RoomMember { - notifi.user_is_target = - event.state_key.as_deref() == Some(event.sender.as_str()); - } + if event.kind == TimelineEventType::RoomMember { + notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str()); + } - notifi.sender_display_name = services().users.displayname(&event.sender)?; + notifi.sender_display_name = services().users.displayname(&event.sender)?; - notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; + notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?; - self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)) - .await?; - } + self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?; + } - Ok(()) - } - // TODO: Handle email - //PusherKind::Email(_) => Ok(()), - _ => Ok(()), - } - } + Ok(()) + }, + // TODO: Handle email + //PusherKind::Email(_) => Ok(()), + _ => Ok(()), + } + } } diff --git a/src/service/rooms/alias/data.rs b/src/service/rooms/alias/data.rs index e2647ffc..095d6e66 100644 --- a/src/service/rooms/alias/data.rs +++ b/src/service/rooms/alias/data.rs @@ -1,24 +1,22 @@ -use crate::Result; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; +use crate::Result; + pub trait Data: Send + Sync { - /// Creates or updates the alias to the given room id. - fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; + /// Creates or updates the alias to the given room id. + fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>; - /// Forgets about an alias. 
Returns an error if the alias did not exist. - fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; + /// Forgets about an alias. Returns an error if the alias did not exist. + fn remove_alias(&self, alias: &RoomAliasId) -> Result<()>; - /// Looks up the roomid for the given alias. - fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>; + /// Looks up the roomid for the given alias. + fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result>; - /// Returns all local aliases that point to the given room - fn local_aliases_for_room<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a>; + /// Returns all local aliases that point to the given room + fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a>; - /// Returns all local aliases on the server - fn all_local_aliases<'a>( - &'a self, - ) -> Box> + 'a>; + /// Returns all local aliases on the server + fn all_local_aliases<'a>(&'a self) -> Box> + 'a>; } diff --git a/src/service/rooms/alias/mod.rs b/src/service/rooms/alias/mod.rs index 34a5732b..a52faefe 100644 --- a/src/service/rooms/alias/mod.rs +++ b/src/service/rooms/alias/mod.rs @@ -1,42 +1,35 @@ mod data; pub use data::Data; - -use crate::Result; use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId}; +use crate::Result; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - #[tracing::instrument(skip(self))] - pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { - self.db.set_alias(alias, room_id) - } + #[tracing::instrument(skip(self))] + pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) } - #[tracing::instrument(skip(self))] - pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { - self.db.remove_alias(alias) - } + #[tracing::instrument(skip(self))] + pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) } - 
#[tracing::instrument(skip(self))] - pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { - self.db.resolve_local_alias(alias) - } + #[tracing::instrument(skip(self))] + pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result> { + self.db.resolve_local_alias(alias) + } - #[tracing::instrument(skip(self))] - pub fn local_aliases_for_room<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a> { - self.db.local_aliases_for_room(room_id) - } + #[tracing::instrument(skip(self))] + pub fn local_aliases_for_room<'a>( + &'a self, room_id: &RoomId, + ) -> Box> + 'a> { + self.db.local_aliases_for_room(room_id) + } - #[tracing::instrument(skip(self))] - pub fn all_local_aliases<'a>( - &'a self, - ) -> Box> + 'a> { - self.db.all_local_aliases() - } + #[tracing::instrument(skip(self))] + pub fn all_local_aliases<'a>(&'a self) -> Box> + 'a> { + self.db.all_local_aliases() + } } diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index e8c379fc..c83b4eb0 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -1,11 +1,8 @@ -use crate::Result; use std::{collections::HashSet, sync::Arc}; +use crate::Result; + pub trait Data: Send + Sync { - fn get_cached_eventid_authchain( - &self, - shorteventid: &[u64], - ) -> Result>>>; - fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc>) - -> Result<()>; + fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result>>>; + fn cache_auth_chain(&self, shorteventid: Vec, auth_chain: Arc>) -> Result<()>; } diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index da1944e2..2aa80442 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -1,7 +1,7 @@ mod data; use std::{ - collections::{BTreeSet, HashSet}, - sync::Arc, + collections::{BTreeSet, HashSet}, + sync::Arc, }; pub use data::Data; @@ -11,151 +11,130 @@ use tracing::{debug, error, warn}; use 
crate::{services, Error, Result}; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { - self.db.get_cached_eventid_authchain(key) - } + pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result>>> { + self.db.get_cached_eventid_authchain(key) + } - #[tracing::instrument(skip(self))] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { - self.db.cache_auth_chain(key, auth_chain) - } + #[tracing::instrument(skip(self))] + pub fn cache_auth_chain(&self, key: Vec, auth_chain: Arc>) -> Result<()> { + self.db.cache_auth_chain(key, auth_chain) + } - #[tracing::instrument(skip(self, starting_events))] - pub async fn get_auth_chain<'a>( - &self, - room_id: &RoomId, - starting_events: Vec>, - ) -> Result> + 'a> { - const NUM_BUCKETS: usize = 50; + #[tracing::instrument(skip(self, starting_events))] + pub async fn get_auth_chain<'a>( + &self, room_id: &RoomId, starting_events: Vec>, + ) -> Result> + 'a> { + const NUM_BUCKETS: usize = 50; - let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; + let mut buckets = vec![BTreeSet::new(); NUM_BUCKETS]; - let mut i = 0; - for id in starting_events { - let short = services().rooms.short.get_or_create_shorteventid(&id)?; - let bucket_id = (short % NUM_BUCKETS as u64) as usize; - buckets[bucket_id].insert((short, id.clone())); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } + let mut i = 0; + for id in starting_events { + let short = services().rooms.short.get_or_create_shorteventid(&id)?; + let bucket_id = (short % NUM_BUCKETS as u64) as usize; + buckets[bucket_id].insert((short, id.clone())); + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + } - let mut full_auth_chain = HashSet::new(); + let mut full_auth_chain = HashSet::new(); - let mut hits = 0; - let mut misses = 0; - for chunk in buckets { - if chunk.is_empty() { - continue; - } + let 
mut hits = 0; + let mut misses = 0; + for chunk in buckets { + if chunk.is_empty() { + continue; + } - let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); - if let Some(cached) = services() - .rooms - .auth_chain - .get_cached_eventid_authchain(&chunk_key)? - { - hits += 1; - full_auth_chain.extend(cached.iter().copied()); - continue; - } - misses += 1; + let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? { + hits += 1; + full_auth_chain.extend(cached.iter().copied()); + continue; + } + misses += 1; - let mut chunk_cache = HashSet::new(); - let mut hits2 = 0; - let mut misses2 = 0; - let mut i = 0; - for (sevent_id, event_id) in chunk { - if let Some(cached) = services() - .rooms - .auth_chain - .get_cached_eventid_authchain(&[sevent_id])? - { - hits2 += 1; - chunk_cache.extend(cached.iter().copied()); - } else { - misses2 += 1; - let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); - services() - .rooms - .auth_chain - .cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; - debug!( - event_id = ?event_id, - chain_length = ?auth_chain.len(), - "Cache missed event" - ); - chunk_cache.extend(auth_chain.iter()); + let mut chunk_cache = HashSet::new(); + let mut hits2 = 0; + let mut misses2 = 0; + let mut i = 0; + for (sevent_id, event_id) in chunk { + if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? 
{ + hits2 += 1; + chunk_cache.extend(cached.iter().copied()); + } else { + misses2 += 1; + let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?); + services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?; + debug!( + event_id = ?event_id, + chain_length = ?auth_chain.len(), + "Cache missed event" + ); + chunk_cache.extend(auth_chain.iter()); - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - }; - } - debug!( - chunk_cache_length = ?chunk_cache.len(), - hits = ?hits2, - misses = ?misses2, - "Chunk missed", - ); - let chunk_cache = Arc::new(chunk_cache); - services() - .rooms - .auth_chain - .cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; - full_auth_chain.extend(chunk_cache.iter()); - } + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + }; + } + debug!( + chunk_cache_length = ?chunk_cache.len(), + hits = ?hits2, + misses = ?misses2, + "Chunk missed", + ); + let chunk_cache = Arc::new(chunk_cache); + services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?; + full_auth_chain.extend(chunk_cache.iter()); + } - debug!( - chain_length = ?full_auth_chain.len(), - hits = ?hits, - misses = ?misses, - "Auth chain stats", - ); + debug!( + chain_length = ?full_auth_chain.len(), + hits = ?hits, + misses = ?misses, + "Auth chain stats", + ); - Ok(full_auth_chain - .into_iter() - .filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) - } + Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok())) + } - #[tracing::instrument(skip(self, event_id))] - fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { - let mut todo = vec![Arc::from(event_id)]; - let mut found = HashSet::new(); + #[tracing::instrument(skip(self, event_id))] + fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + let mut todo = vec![Arc::from(event_id)]; + 
let mut found = HashSet::new(); - while let Some(event_id) = todo.pop() { - match services().rooms.timeline.get_pdu(&event_id) { - Ok(Some(pdu)) => { - if pdu.room_id != room_id { - return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); - } - for auth_event in &pdu.auth_events { - let sauthevent = services() - .rooms - .short - .get_or_create_shorteventid(auth_event)?; + while let Some(event_id) = todo.pop() { + match services().rooms.timeline.get_pdu(&event_id) { + Ok(Some(pdu)) => { + if pdu.room_id != room_id { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db")); + } + for auth_event in &pdu.auth_events { + let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?; - if !found.contains(&sauthevent) { - found.insert(sauthevent); - todo.push(auth_event.clone()); - } - } - } - Ok(None) => { - warn!(?event_id, "Could not find pdu mentioned in auth events"); - } - Err(error) => { - error!(?event_id, ?error, "Could not load event in auth chain"); - } - } - } + if !found.contains(&sauthevent) { + found.insert(sauthevent); + todo.push(auth_event.clone()); + } + } + }, + Ok(None) => { + warn!(?event_id, "Could not find pdu mentioned in auth events"); + }, + Err(error) => { + error!(?event_id, ?error, "Could not load event in auth chain"); + }, + } + } - Ok(found) - } + Ok(found) + } } diff --git a/src/service/rooms/directory/data.rs b/src/service/rooms/directory/data.rs index aca731ce..691b8604 100644 --- a/src/service/rooms/directory/data.rs +++ b/src/service/rooms/directory/data.rs @@ -1,16 +1,17 @@ -use crate::Result; use ruma::{OwnedRoomId, RoomId}; +use crate::Result; + pub trait Data: Send + Sync { - /// Adds the room to the public room directory - fn set_public(&self, room_id: &RoomId) -> Result<()>; + /// Adds the room to the public room directory + fn set_public(&self, room_id: &RoomId) -> Result<()>; - /// Removes the room from the public room directory. 
- fn set_not_public(&self, room_id: &RoomId) -> Result<()>; + /// Removes the room from the public room directory. + fn set_not_public(&self, room_id: &RoomId) -> Result<()>; - /// Returns true if the room is in the public room directory. - fn is_public_room(&self, room_id: &RoomId) -> Result; + /// Returns true if the room is in the public room directory. + fn is_public_room(&self, room_id: &RoomId) -> Result; - /// Returns the unsorted public room directory - fn public_rooms<'a>(&'a self) -> Box> + 'a>; + /// Returns the unsorted public room directory + fn public_rooms<'a>(&'a self) -> Box> + 'a>; } diff --git a/src/service/rooms/directory/mod.rs b/src/service/rooms/directory/mod.rs index 10f782bb..0efc365c 100644 --- a/src/service/rooms/directory/mod.rs +++ b/src/service/rooms/directory/mod.rs @@ -6,27 +6,19 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - #[tracing::instrument(skip(self))] - pub fn set_public(&self, room_id: &RoomId) -> Result<()> { - self.db.set_public(room_id) - } + #[tracing::instrument(skip(self))] + pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) } - #[tracing::instrument(skip(self))] - pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { - self.db.set_not_public(room_id) - } + #[tracing::instrument(skip(self))] + pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) } - #[tracing::instrument(skip(self))] - pub fn is_public_room(&self, room_id: &RoomId) -> Result { - self.db.is_public_room(room_id) - } + #[tracing::instrument(skip(self))] + pub fn is_public_room(&self, room_id: &RoomId) -> Result { self.db.is_public_room(room_id) } - #[tracing::instrument(skip(self))] - pub fn public_rooms(&self) -> impl Iterator> + '_ { - self.db.public_rooms() - } + #[tracing::instrument(skip(self))] + pub fn public_rooms(&self) -> impl Iterator> + '_ { 
self.db.public_rooms() } } diff --git a/src/service/rooms/edus/mod.rs b/src/service/rooms/edus/mod.rs index cf7a3591..593265cb 100644 --- a/src/service/rooms/edus/mod.rs +++ b/src/service/rooms/edus/mod.rs @@ -5,7 +5,7 @@ pub mod typing; pub trait Data: presence::Data + read_receipt::Data + typing::Data + 'static {} pub struct Service { - pub presence: presence::Service, - pub read_receipt: read_receipt::Service, - pub typing: typing::Service, + pub presence: presence::Service, + pub read_receipt: read_receipt::Service, + pub typing: typing::Service, } diff --git a/src/service/rooms/edus/presence/data.rs b/src/service/rooms/edus/presence/data.rs index c13ac6f4..6b7ad4c2 100644 --- a/src/service/rooms/edus/presence/data.rs +++ b/src/service/rooms/edus/presence/data.rs @@ -1,33 +1,27 @@ +use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId}; + use crate::Result; -use ruma::{ - events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId, -}; pub trait Data: Send + Sync { - /// Returns the latest presence event for the given user in the given room. - fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + /// Returns the latest presence event for the given user in the given room. + fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - /// Pings the presence of the given user in the given room, setting the specified state. - fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>; + /// Pings the presence of the given user in the given room, setting the + /// specified state. + fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>; - /// Adds a presence event which will be saved until a new event replaces it. 
- fn set_presence( - &self, - room_id: &RoomId, - user_id: &UserId, - presence_state: PresenceState, - currently_active: Option, - last_active_ago: Option, - status_msg: Option, - ) -> Result<()>; + /// Adds a presence event which will be saved until a new event replaces it. + fn set_presence( + &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option, + last_active_ago: Option, status_msg: Option, + ) -> Result<()>; - /// Removes the presence record for the given user from the database. - fn remove_presence(&self, user_id: &UserId) -> Result<()>; + /// Removes the presence record for the given user from the database. + fn remove_presence(&self, user_id: &UserId) -> Result<()>; - /// Returns the most recent presence updates that happened after the event with id `since`. - fn presence_since<'a>( - &'a self, - room_id: &RoomId, - since: u64, - ) -> Box + 'a>; + /// Returns the most recent presence updates that happened after the event + /// with id `since`. + fn presence_since<'a>( + &'a self, room_id: &RoomId, since: u64, + ) -> Box + 'a>; } diff --git a/src/service/rooms/edus/presence/mod.rs b/src/service/rooms/edus/presence/mod.rs index fb929c45..f2b335b3 100644 --- a/src/service/rooms/edus/presence/mod.rs +++ b/src/service/rooms/edus/presence/mod.rs @@ -5,9 +5,9 @@ use std::time::Duration; pub use data::Data; use futures_util::{stream::FuturesUnordered, StreamExt}; use ruma::{ - events::presence::{PresenceEvent, PresenceEventContent}, - presence::PresenceState, - OwnedUserId, RoomId, UInt, UserId, + events::presence::{PresenceEvent, PresenceEventContent}, + presence::PresenceState, + OwnedUserId, RoomId, UInt, UserId, }; use serde::{Deserialize, Serialize}; use tokio::{sync::mpsc, time::sleep}; @@ -15,197 +15,164 @@ use tracing::debug; use crate::{services, utils, Error, Result}; -/// Represents data required to be kept in order to implement the presence specification. 
+/// Represents data required to be kept in order to implement the presence +/// specification. #[derive(Serialize, Deserialize, Debug, Clone)] pub struct Presence { - pub state: PresenceState, - pub currently_active: bool, - pub last_active_ts: u64, - pub last_count: u64, - pub status_msg: Option, + pub state: PresenceState, + pub currently_active: bool, + pub last_active_ts: u64, + pub last_count: u64, + pub status_msg: Option, } impl Presence { - pub fn new( - state: PresenceState, - currently_active: bool, - last_active_ts: u64, - last_count: u64, - status_msg: Option, - ) -> Self { - Self { - state, - currently_active, - last_active_ts, - last_count, - status_msg, - } - } + pub fn new( + state: PresenceState, currently_active: bool, last_active_ts: u64, last_count: u64, status_msg: Option, + ) -> Self { + Self { + state, + currently_active, + last_active_ts, + last_count, + status_msg, + } + } - pub fn from_json_bytes(bytes: &[u8]) -> Result { - serde_json::from_slice(bytes) - .map_err(|_| Error::bad_database("Invalid presence data in database")) - } + pub fn from_json_bytes(bytes: &[u8]) -> Result { + serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database")) + } - pub fn to_json_bytes(&self) -> Result> { - serde_json::to_vec(self) - .map_err(|_| Error::bad_database("Could not serialize Presence to JSON")) - } + pub fn to_json_bytes(&self) -> Result> { + serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON")) + } - /// Creates a PresenceEvent from available data. - pub fn to_presence_event(&self, user_id: &UserId) -> Result { - let now = utils::millis_since_unix_epoch(); - let last_active_ago = if self.currently_active { - None - } else { - Some(UInt::new_saturating( - now.saturating_sub(self.last_active_ts), - )) - }; + /// Creates a PresenceEvent from available data. 
+ pub fn to_presence_event(&self, user_id: &UserId) -> Result { + let now = utils::millis_since_unix_epoch(); + let last_active_ago = if self.currently_active { + None + } else { + Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts))) + }; - Ok(PresenceEvent { - sender: user_id.to_owned(), - content: PresenceEventContent { - presence: self.state.clone(), - status_msg: self.status_msg.clone(), - currently_active: Some(self.currently_active), - last_active_ago, - displayname: services().users.displayname(user_id)?, - avatar_url: services().users.avatar_url(user_id)?, - }, - }) - } + Ok(PresenceEvent { + sender: user_id.to_owned(), + content: PresenceEventContent { + presence: self.state.clone(), + status_msg: self.status_msg.clone(), + currently_active: Some(self.currently_active), + last_active_ago, + displayname: services().users.displayname(user_id)?, + avatar_url: services().users.avatar_url(user_id)?, + }, + }) + } } pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Returns the latest presence event for the given user in the given room. - pub fn get_presence( - &self, - room_id: &RoomId, - user_id: &UserId, - ) -> Result> { - self.db.get_presence(room_id, user_id) - } + /// Returns the latest presence event for the given user in the given room. + pub fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.get_presence(room_id, user_id) + } - /// Pings the presence of the given user in the given room, setting the specified state. - pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> { - self.db.ping_presence(user_id, new_state) - } + /// Pings the presence of the given user in the given room, setting the + /// specified state. + pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> { + self.db.ping_presence(user_id, new_state) + } - /// Adds a presence event which will be saved until a new event replaces it. 
- pub fn set_presence( - &self, - room_id: &RoomId, - user_id: &UserId, - presence_state: PresenceState, - currently_active: Option, - last_active_ago: Option, - status_msg: Option, - ) -> Result<()> { - self.db.set_presence( - room_id, - user_id, - presence_state, - currently_active, - last_active_ago, - status_msg, - ) - } + /// Adds a presence event which will be saved until a new event replaces it. + pub fn set_presence( + &self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option, + last_active_ago: Option, status_msg: Option, + ) -> Result<()> { + self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg) + } - /// Removes the presence record for the given user from the database. - pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { - self.db.remove_presence(user_id) - } + /// Removes the presence record for the given user from the database. + pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) } - /// Returns the most recent presence updates that happened after the event with id `since`. - pub fn presence_since( - &self, - room_id: &RoomId, - since: u64, - ) -> Box> { - self.db.presence_since(room_id, since) - } + /// Returns the most recent presence updates that happened after the event + /// with id `since`. + pub fn presence_since( + &self, room_id: &RoomId, since: u64, + ) -> Box> { + self.db.presence_since(room_id, since) + } } pub async fn presence_handler( - mut presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>, + mut presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>, ) -> Result<()> { - let mut presence_timers = FuturesUnordered::new(); + let mut presence_timers = FuturesUnordered::new(); - loop { - debug!("Number of presence timers: {}", presence_timers.len()); + loop { + debug!("Number of presence timers: {}", presence_timers.len()); - tokio::select! 
{ - Some((user_id, timeout)) = presence_timer_receiver.recv() => { - debug!("Adding timer for user '{user_id}': Timeout {timeout:?}"); - presence_timers.push(presence_timer(user_id, timeout)); - } + tokio::select! { + Some((user_id, timeout)) = presence_timer_receiver.recv() => { + debug!("Adding timer for user '{user_id}': Timeout {timeout:?}"); + presence_timers.push(presence_timer(user_id, timeout)); + } - Some(user_id) = presence_timers.next() => { - process_presence_timer(user_id)?; - } - } - } + Some(user_id) = presence_timers.next() => { + process_presence_timer(user_id)?; + } + } + } } async fn presence_timer(user_id: OwnedUserId, timeout: Duration) -> OwnedUserId { - sleep(timeout).await; + sleep(timeout).await; - user_id + user_id } fn process_presence_timer(user_id: OwnedUserId) -> Result<()> { - let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000; - let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000; + let idle_timeout = services().globals.config.presence_idle_timeout_s * 1_000; + let offline_timeout = services().globals.config.presence_offline_timeout_s * 1_000; - let mut presence_state = PresenceState::Offline; - let mut last_active_ago = None; - let mut status_msg = None; + let mut presence_state = PresenceState::Offline; + let mut last_active_ago = None; + let mut status_msg = None; - for room_id in services().rooms.state_cache.rooms_joined(&user_id) { - let presence_event = services() - .rooms - .edus - .presence - .get_presence(&room_id?, &user_id)?; + for room_id in services().rooms.state_cache.rooms_joined(&user_id) { + let presence_event = services().rooms.edus.presence.get_presence(&room_id?, &user_id)?; - if let Some(presence_event) = presence_event { - presence_state = presence_event.content.presence; - last_active_ago = presence_event.content.last_active_ago; - status_msg = presence_event.content.status_msg; + if let Some(presence_event) = presence_event { + presence_state = 
presence_event.content.presence; + last_active_ago = presence_event.content.last_active_ago; + status_msg = presence_event.content.status_msg; - break; - } - } + break; + } + } - let new_state = match (&presence_state, last_active_ago.map(u64::from)) { - (PresenceState::Online, Some(ago)) if ago >= idle_timeout => { - Some(PresenceState::Unavailable) - } - (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => { - Some(PresenceState::Offline) - } - _ => None, - }; + let new_state = match (&presence_state, last_active_ago.map(u64::from)) { + (PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable), + (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline), + _ => None, + }; - debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"); + debug!("Processed presence timer for user '{user_id}': Old state = {presence_state}, New state = {new_state:?}"); - if let Some(new_state) = new_state { - for room_id in services().rooms.state_cache.rooms_joined(&user_id) { - services().rooms.edus.presence.set_presence( - &room_id?, - &user_id, - new_state.clone(), - Some(false), - last_active_ago, - status_msg.clone(), - )?; - } - } + if let Some(new_state) = new_state { + for room_id in services().rooms.state_cache.rooms_joined(&user_id) { + services().rooms.edus.presence.set_presence( + &room_id?, + &user_id, + new_state.clone(), + Some(false), + last_active_ago, + status_msg.clone(), + )?; + } + } - Ok(()) + Ok(()) } diff --git a/src/service/rooms/edus/read_receipt/data.rs b/src/service/rooms/edus/read_receipt/data.rs index 29b4a986..dcb550f8 100644 --- a/src/service/rooms/edus/read_receipt/data.rs +++ b/src/service/rooms/edus/read_receipt/data.rs @@ -1,32 +1,28 @@ -use crate::Result; use ruma::{ - events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent}, - serde::Raw, - OwnedUserId, RoomId, UserId, + events::{receipt::ReceiptEvent, 
AnySyncEphemeralRoomEvent}, + serde::Raw, + OwnedUserId, RoomId, UserId, }; +use crate::Result; + type AnySyncEphemeralRoomEventIter<'a> = - Box)>> + 'a>; + Box)>> + 'a>; pub trait Data: Send + Sync { - /// Replaces the previous read receipt. - fn readreceipt_update( - &self, - user_id: &UserId, - room_id: &RoomId, - event: ReceiptEvent, - ) -> Result<()>; + /// Replaces the previous read receipt. + fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>; - /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. - fn readreceipts_since(&self, room_id: &RoomId, since: u64) - -> AnySyncEphemeralRoomEventIter<'_>; + /// Returns an iterator over the most recent read_receipts in a room that + /// happened after the event with id `since`. + fn readreceipts_since(&self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'_>; - /// Sets a private read marker at `count`. - fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; + /// Sets a private read marker at `count`. + fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>; - /// Returns the private read marker. - fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + /// Returns the private read marker. + fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - /// Returns the count of the last typing update in this room. - fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; + /// Returns the count of the last typing update in this room. 
+ fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/edus/read_receipt/mod.rs b/src/service/rooms/edus/read_receipt/mod.rs index c6035280..9b2bcd94 100644 --- a/src/service/rooms/edus/read_receipt/mod.rs +++ b/src/service/rooms/edus/read_receipt/mod.rs @@ -1,55 +1,43 @@ mod data; pub use data::Data; - -use crate::Result; use ruma::{events::receipt::ReceiptEvent, serde::Raw, OwnedUserId, RoomId, UserId}; +use crate::Result; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Replaces the previous read receipt. - pub fn readreceipt_update( - &self, - user_id: &UserId, - room_id: &RoomId, - event: ReceiptEvent, - ) -> Result<()> { - self.db.readreceipt_update(user_id, room_id, event) - } + /// Replaces the previous read receipt. + pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> { + self.db.readreceipt_update(user_id, room_id, event) + } - /// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`. - #[tracing::instrument(skip(self))] - pub fn readreceipts_since<'a>( - &'a self, - room_id: &RoomId, - since: u64, - ) -> impl Iterator< - Item = Result<( - OwnedUserId, - u64, - Raw, - )>, - > + 'a { - self.db.readreceipts_since(room_id, since) - } + /// Returns an iterator over the most recent read_receipts in a room that + /// happened after the event with id `since`. + #[tracing::instrument(skip(self))] + pub fn readreceipts_since<'a>( + &'a self, room_id: &RoomId, since: u64, + ) -> impl Iterator)>> + 'a { + self.db.readreceipts_since(room_id, since) + } - /// Sets a private read marker at `count`. - #[tracing::instrument(skip(self))] - pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { - self.db.private_read_set(room_id, user_id, count) - } + /// Sets a private read marker at `count`. 
+ #[tracing::instrument(skip(self))] + pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> { + self.db.private_read_set(room_id, user_id, count) + } - /// Returns the private read marker. - #[tracing::instrument(skip(self))] - pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.private_read_get(room_id, user_id) - } + /// Returns the private read marker. + #[tracing::instrument(skip(self))] + pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.private_read_get(room_id, user_id) + } - /// Returns the count of the last typing update in this room. - pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_privateread_update(user_id, room_id) - } + /// Returns the count of the last typing update in this room. + pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.last_privateread_update(user_id, room_id) + } } diff --git a/src/service/rooms/edus/typing/data.rs b/src/service/rooms/edus/typing/data.rs index 3b1eecfb..cadc30ee 100644 --- a/src/service/rooms/edus/typing/data.rs +++ b/src/service/rooms/edus/typing/data.rs @@ -1,21 +1,23 @@ -use crate::Result; -use ruma::{OwnedUserId, RoomId, UserId}; use std::collections::HashSet; +use ruma::{OwnedUserId, RoomId, UserId}; + +use crate::Result; + pub trait Data: Send + Sync { - /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is - /// called. - fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()>; + /// Sets a user as typing until the timeout timestamp is reached or + /// roomtyping_remove is called. + fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()>; - /// Removes a user from typing before the timeout is reached. 
- fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + /// Removes a user from typing before the timeout is reached. + fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - /// Makes sure that typing events with old timestamps get removed. - fn typings_maintain(&self, room_id: &RoomId) -> Result<()>; + /// Makes sure that typing events with old timestamps get removed. + fn typings_maintain(&self, room_id: &RoomId) -> Result<()>; - /// Returns the count of the last typing update in this room. - fn last_typing_update(&self, room_id: &RoomId) -> Result; + /// Returns the count of the last typing update in this room. + fn last_typing_update(&self, room_id: &RoomId) -> Result; - /// Returns all user ids currently typing. - fn typings_all(&self, room_id: &RoomId) -> Result>; + /// Returns all user ids currently typing. + fn typings_all(&self, room_id: &RoomId) -> Result>; } diff --git a/src/service/rooms/edus/typing/mod.rs b/src/service/rooms/edus/typing/mod.rs index 7d44f7d7..59231eb4 100644 --- a/src/service/rooms/edus/typing/mod.rs +++ b/src/service/rooms/edus/typing/mod.rs @@ -6,44 +6,41 @@ use ruma::{events::SyncEphemeralRoomEvent, RoomId, UserId}; use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Sets a user as typing until the timeout timestamp is reached or roomtyping_remove is - /// called. - pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { - self.db.typing_add(user_id, room_id, timeout) - } + /// Sets a user as typing until the timeout timestamp is reached or + /// roomtyping_remove is called. + pub fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> { + self.db.typing_add(user_id, room_id, timeout) + } - /// Removes a user from typing before the timeout is reached. 
- pub fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.typing_remove(user_id, room_id) - } + /// Removes a user from typing before the timeout is reached. + pub fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + self.db.typing_remove(user_id, room_id) + } - /// Makes sure that typing events with old timestamps get removed. - fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { - self.db.typings_maintain(room_id) - } + /// Makes sure that typing events with old timestamps get removed. + fn typings_maintain(&self, room_id: &RoomId) -> Result<()> { self.db.typings_maintain(room_id) } - /// Returns the count of the last typing update in this room. - pub fn last_typing_update(&self, room_id: &RoomId) -> Result { - self.typings_maintain(room_id)?; + /// Returns the count of the last typing update in this room. + pub fn last_typing_update(&self, room_id: &RoomId) -> Result { + self.typings_maintain(room_id)?; - self.db.last_typing_update(room_id) - } + self.db.last_typing_update(room_id) + } - /// Returns a new typing EDU. - pub fn typings_all( - &self, - room_id: &RoomId, - ) -> Result> { - let user_ids = self.db.typings_all(room_id)?; + /// Returns a new typing EDU. + pub fn typings_all( + &self, room_id: &RoomId, + ) -> Result> { + let user_ids = self.db.typings_all(room_id)?; - Ok(SyncEphemeralRoomEvent { - content: ruma::events::typing::TypingEventContent { - user_ids: user_ids.into_iter().collect(), - }, - }) - } + Ok(SyncEphemeralRoomEvent { + content: ruma::events::typing::TypingEventContent { + user_ids: user_ids.into_iter().collect(), + }, + }) + } } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index 2b2e84ea..21864e12 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -1,1941 +1,1607 @@ /// An async function that can recursively call itself. 
type AsyncRecursiveType<'a, T> = Pin + 'a + Send>>; -use ruma::{ - api::federation::discovery::{get_remote_server_keys, get_server_keys}, - CanonicalJsonObject, CanonicalJsonValue, OwnedServerName, OwnedServerSigningKeyId, - RoomVersionId, -}; use std::{ - collections::{hash_map, HashSet}, - pin::Pin, - sync::RwLockWriteGuard, - time::{Duration, Instant, SystemTime}, + collections::{hash_map, HashSet}, + pin::Pin, + sync::RwLockWriteGuard, + time::{Duration, Instant, SystemTime}, }; -use tokio::sync::Semaphore; use futures_util::{stream::FuturesUnordered, Future, StreamExt}; use ruma::{ - api::{ - client::error::ErrorKind, - federation::{ - discovery::get_remote_server_keys_batch::{self, v2::QueryCriteria}, - event::{get_event, get_room_state_ids}, - membership::create_join_event, - }, - }, - events::{ - room::{create::RoomCreateEventContent, server_acl::RoomServerAclEventContent}, - StateEventType, - }, - int, - serde::Base64, - state_res::{self, RoomVersion, StateMap}, - uint, EventId, MilliSecondsSinceUnixEpoch, RoomId, ServerName, + api::{ + client::error::ErrorKind, + federation::{ + discovery::{ + get_remote_server_keys, + get_remote_server_keys_batch::{self, v2::QueryCriteria}, + get_server_keys, + }, + event::{get_event, get_room_state_ids}, + membership::create_join_event, + }, + }, + events::{ + room::{create::RoomCreateEventContent, server_acl::RoomServerAclEventContent}, + StateEventType, + }, + int, + serde::Base64, + state_res::{self, RoomVersion, StateMap}, + uint, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedServerName, + OwnedServerSigningKeyId, RoomId, RoomVersionId, ServerName, }; use serde_json::value::RawValue as RawJsonValue; +use tokio::sync::Semaphore; use tracing::{debug, error, info, trace, warn}; +use super::state_compressor::CompressedStateEvent; use crate::{ - service::{pdu, Arc, BTreeMap, HashMap, Result, RwLock}, - services, Error, PduEvent, + service::{pdu, Arc, BTreeMap, HashMap, Result, 
RwLock}, + services, Error, PduEvent, }; -use super::state_compressor::CompressedStateEvent; - type AsyncRecursiveCanonicalJsonVec<'a> = - AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; + AsyncRecursiveType<'a, Vec<(Arc, Option>)>>; type AsyncRecursiveCanonicalJsonResult<'a> = - AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; + AsyncRecursiveType<'a, Result<(Arc, BTreeMap)>>; pub struct Service; impl Service { - /// When receiving an event one needs to: - /// 0. Check the server is in the room - /// 1. Skip the PDU if we already know about it - /// 1.1. Remove unsigned field - /// 2. Check signatures, otherwise drop - /// 3. Check content hash, redact if doesn't match - /// 4. Fetch any missing auth events doing all checks listed here starting at 1. These are not - /// timeline events - /// 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are - /// also rejected "due to auth events" - /// 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events - /// 7. Persist this event as an outlier - /// 8. If not timeline event: stop - /// 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline - /// events - /// 10. Fetch missing state and auth chain events by calling `/state_ids` at backwards extremities - /// doing all the checks in this list starting at 1. These are not timeline events - /// 11. Check the auth of the event passes based on the state of the event - /// 12. Ensure that the state is derived from the previous current state (i.e. we calculated by - /// doing state res where one of the inputs was a previously trusted set of state, don't just - /// trust a set of state we got from a remote) - /// 13. Use state resolution to find new room state - /// 14. 
Check if the event passes auth based on the "current state" of the room, if not soft fail it - // We use some AsyncRecursiveType hacks here so we can call this async funtion recursively - pub(crate) async fn handle_incoming_pdu<'a>( - &self, - origin: &'a ServerName, - event_id: &'a EventId, - room_id: &'a RoomId, - value: BTreeMap, - is_timeline_event: bool, - pub_key_map: &'a RwLock>>, - ) -> Result>> { - // 0. Check the server is in the room - if !services().rooms.metadata.exists(room_id)? { - return Err(Error::BadRequest( - ErrorKind::NotFound, - "Room is unknown to this server", - )); - } - - if services().rooms.metadata.is_disabled(room_id)? { - info!("Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and event ID {event_id}"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Federation of this room is currently disabled on this server.", - )); - } - - services().rooms.event_handler.acl_check(origin, room_id)?; - - // 1. Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { - return Ok(Some(pdu_id)); - } - - let create_event = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; - - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; - let room_version_id = &create_event_content.room_version; - - let first_pdu_in_room = services() - .rooms - .timeline - .first_pdu_in_room(room_id)? 
- .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - - let (incoming_pdu, val) = self - .handle_outlier_pdu( - origin, - &create_event, - event_id, - room_id, - value, - false, - pub_key_map, - ) - .await?; - self.check_room_id(room_id, &incoming_pdu)?; - - // 8. if not timeline event: stop - if !is_timeline_event { - return Ok(None); - } - - // Skip old events - if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - return Ok(None); - } - - // 9. Fetch any missing prev events doing all checks listed here starting at 1. These are timeline events - let (sorted_prev_events, mut eventid_info) = self - .fetch_unknown_prev_events( - origin, - &create_event, - room_id, - room_version_id, - pub_key_map, - incoming_pdu.prev_events.clone(), - ) - .await?; - - let mut errors = 0; - debug!(events = ?sorted_prev_events, "Got previous events"); - for prev_id in sorted_prev_events { - // Check for disabled again because it might have changed - if services().rooms.metadata.is_disabled(room_id)? { - info!("Federaton of room {room_id} is currently disabled on this server. 
Request by origin {origin} and event ID {event_id}"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Federation of this room is currently disabled on this server.", - )); - } - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&*prev_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", prev_id); - continue; - } - } - - if errors >= 5 { - // Timeout other events - match services() - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); - } - } - continue; - } - - if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { - // Skip old events - if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { - continue; - } - - let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); - - if let Err(e) = self - .upgrade_outlier_to_timeline_pdu( - pdu, - json, - &create_event, - origin, - room_id, - pub_key_map, - ) - .await - { - errors += 1; - warn!("Prev event {} failed: {}", prev_id, e); - match services() - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry((*prev_id).to_owned()) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => { - *e.get_mut() = (Instant::now(), e.get().1 + 1); - } - } - } - let elapsed = start_time.elapsed(); - services() - .globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - 
debug!( - "Handling prev event {} took {}m{}s", - prev_id, - elapsed.as_secs() / 60, - elapsed.as_secs() % 60 - ); - } - } - - // Done with prev events, now handling the incoming event - - let start_time = Instant::now(); - services() - .globals - .roomid_federationhandletime - .write() - .unwrap() - .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); - let r = services() - .rooms - .event_handler - .upgrade_outlier_to_timeline_pdu( - incoming_pdu, - val, - &create_event, - origin, - room_id, - pub_key_map, - ) - .await; - services() - .globals - .roomid_federationhandletime - .write() - .unwrap() - .remove(&room_id.to_owned()); - - r - } - - #[allow(clippy::too_many_arguments)] - fn handle_outlier_pdu<'a>( - &'a self, - origin: &'a ServerName, - create_event: &'a PduEvent, - event_id: &'a EventId, - room_id: &'a RoomId, - mut value: BTreeMap, - auth_events_known: bool, - pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonResult<'a> { - Box::pin(async move { - // 1. Remove unsigned field - value.remove("unsigned"); - - // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json - - // 2. Check signatures, otherwise drop - // 3. 
check content hash, redact if doesn't match - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - error!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; - - let room_version_id = &create_event_content.room_version; - let room_version = - RoomVersion::new(room_version_id).expect("room version is supported"); - - let mut val = match ruma::signatures::verify_event( - &pub_key_map.read().expect("RwLock is poisoned."), - &value, - room_version_id, - ) { - Err(e) => { - // Drop - warn!("Dropping bad event {}: {}", event_id, e,); - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Signature verification failed", - )); - } - Ok(ruma::signatures::Verified::Signatures) => { - // Redact - warn!("Calculated hash does not match: {}", event_id); - let obj = match ruma::canonical_json::redact(value, room_version_id, None) { - Ok(obj) => obj, - Err(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Redaction failed", - )) - } - }; - - // Skip the PDU if it is redacted and we already have it as an outlier event - if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event was redacted and we already knew about it", - )); - } - - obj - } - Ok(ruma::signatures::Verified::All) => value, - }; - - // Now that we have checked the signature and hashes we can add the eventID and convert - // to our PduEvent type - val.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(event_id.as_str().to_owned()), - ); - let incoming_pdu = serde_json::from_value::( - serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), - ) - .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; - - self.check_room_id(room_id, &incoming_pdu)?; - - if !auth_events_known { - // 4. fetch any missing auth events doing all checks listed here starting at 1. 
These are not timeline events - // 5. Reject "due to auth events" if can't get all the auth events or some of the auth events are also rejected "due to auth events" - // NOTE: Step 5 is not applied anymore because it failed too often - debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); - self.fetch_and_handle_outliers( - origin, - &incoming_pdu - .auth_events - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(), - create_event, - room_id, - room_version_id, - pub_key_map, - ) - .await; - } - - // 6. Reject "due to auth events" if the event doesn't pass auth based on the auth events - debug!( - "Auth check for {} based on auth events", - incoming_pdu.event_id - ); - - // Build map of auth events - let mut auth_events = HashMap::new(); - for id in &incoming_pdu.auth_events { - let auth_event = match services().rooms.timeline.get_pdu(id)? { - Some(e) => e, - None => { - warn!("Could not find auth event {}", id); - continue; - } - }; - - self.check_room_id(room_id, &auth_event)?; - - match auth_events.entry(( - auth_event.kind.to_string().into(), - auth_event - .state_key - .clone() - .expect("all auth events have state keys"), - )) { - hash_map::Entry::Vacant(v) => { - v.insert(auth_event); - } - hash_map::Entry::Occupied(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth event's type and state_key combination exists multiple times.", - )); - } - } - } - - // The original create event must be in the auth events - if !matches!( - auth_events - .get(&(StateEventType::RoomCreate, "".to_owned())) - .map(std::convert::AsRef::as_ref), - Some(_) | None - ) { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Incoming event refers to wrong create event.", - )); - } - - if !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, // TODO: third party invite - |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), - ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check 
failed"))? - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Auth check failed", - )); - } - - debug!("Validation successful."); - - // 7. Persist the event as an outlier. - services() - .rooms - .outlier - .add_pdu_outlier(&incoming_pdu.event_id, &val)?; - - debug!("Added pdu as outlier."); - - Ok((Arc::new(incoming_pdu), val)) - }) - } - - pub async fn upgrade_outlier_to_timeline_pdu( - &self, - incoming_pdu: Arc, - val: BTreeMap, - create_event: &PduEvent, - origin: &ServerName, - room_id: &RoomId, - pub_key_map: &RwLock>>, - ) -> Result>> { - // Skip the PDU if we already have it as a timeline event - if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { - return Ok(Some(pduid)); - } - - if services() - .rooms - .pdu_metadata - .is_event_soft_failed(&incoming_pdu.event_id)? - { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event has been soft failed", - )); - } - - info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); - - let create_event_content: RoomCreateEventContent = - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::BadDatabase("Invalid create event in db") - })?; - - let room_version_id = &create_event_content.room_version; - let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); - - // 10. Fetch missing state and auth chain events by calling /state_ids at backwards extremities - // doing all the checks in this list starting at 1. These are not timeline events. 
- - // TODO: if we know the prev_events of the incoming event we can avoid the request and build - // the state from a known point and resolve if > 1 prev_event - - debug!("Requesting state at event"); - let mut state_at_incoming_event = None; - - if incoming_pdu.prev_events.len() == 1 { - let prev_event = &*incoming_pdu.prev_events[0]; - let prev_event_sstatehash = services() - .rooms - .state_accessor - .pdu_shortstatehash(prev_event)?; - - let state = if let Some(shortstatehash) = prev_event_sstatehash { - Some( - services() - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await, - ) - } else { - None - }; - - if let Some(Ok(mut state)) = state { - debug!("Using cached state"); - let prev_pdu = services() - .rooms - .timeline - .get_pdu(prev_event) - .ok() - .flatten() - .ok_or_else(|| { - Error::bad_database("Could not find prev event, but we know the state.") - })?; - - if let Some(state_key) = &prev_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_pdu.kind.to_string().into(), - state_key, - )?; - - state.insert(shortstatekey, Arc::from(prev_event)); - // Now it's the state after the pdu - } - - state_at_incoming_event = Some(state); - } - } else { - debug!("Calculating state at event using state res"); - let mut extremity_sstatehashes = HashMap::new(); - - let mut okay = true; - for prev_eventid in &incoming_pdu.prev_events { - let prev_event = - if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) { - pdu - } else { - okay = false; - break; - }; - - let sstatehash = if let Ok(Some(s)) = services() - .rooms - .state_accessor - .pdu_shortstatehash(prev_eventid) - { - s - } else { - okay = false; - break; - }; - - extremity_sstatehashes.insert(sstatehash, prev_event); - } - - if okay { - let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); - let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); - - for (sstatehash, prev_event) in 
extremity_sstatehashes { - let mut leaf_state: HashMap<_, _> = services() - .rooms - .state_accessor - .state_full_ids(sstatehash) - .await?; - - if let Some(state_key) = &prev_event.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &prev_event.kind.to_string().into(), - state_key, - )?; - leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); - // Now it's the state after the pdu - } - - let mut state = StateMap::with_capacity(leaf_state.len()); - let mut starting_events = Vec::with_capacity(leaf_state.len()); - - for (k, id) in leaf_state { - if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) - { - // FIXME: Undo .to_string().into() when StateMap - // is updated to use StateEventType - state.insert((ty.to_string().into(), st_key), id.clone()); - } else { - warn!("Failed to get_statekey_from_short."); - } - starting_events.push(id); - } - - auth_chain_sets.push( - services() - .rooms - .auth_chain - .get_auth_chain(room_id, starting_events) - .await? 
- .collect(), - ); - - fork_states.push(state); - } - - let lock = services().globals.stateres_mutex.lock(); - - let result = - state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); - drop(lock); - - state_at_incoming_event = match result { - Ok(new_state) => Some( - new_state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = - services().rooms.short.get_or_create_shortstatekey( - &event_type.to_string().into(), - &state_key, - )?; - Ok((shortstatekey, event_id)) - }) - .collect::>()?, - ), - Err(e) => { - warn!("State resolution on prev events failed, either an event could not be found or deserialization: {}", e); - None - } - } - } - } - - if state_at_incoming_event.is_none() { - debug!("Calling /state_ids"); - // Call /state_ids to find out what the state at this pdu is. We trust the server's - // response to some extend, but we still do a lot of checks on the events - match services() - .sending - .send_federation_request( - origin, - get_room_state_ids::v1::Request { - room_id: room_id.to_owned(), - event_id: (*incoming_pdu.event_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - debug!("Fetching state events at event."); - let state_vec = self - .fetch_and_handle_outliers( - origin, - &res.pdu_ids - .iter() - .map(|x| Arc::from(&**x)) - .collect::>(), - create_event, - room_id, - room_version_id, - pub_key_map, - ) - .await; - - let mut state: HashMap<_, Arc> = HashMap::new(); - for (pdu, _) in state_vec { - let state_key = pdu.state_key.clone().ok_or_else(|| { - Error::bad_database("Found non-state pdu in state events.") - })?; - - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &pdu.kind.to_string().into(), - &state_key, - )?; - - match state.entry(shortstatekey) { - hash_map::Entry::Vacant(v) => { - 
v.insert(Arc::from(&*pdu.event_id)); - } - hash_map::Entry::Occupied(_) => return Err( - Error::bad_database("State event's type and state_key combination exists multiple times."), - ), - } - } - - // The original create event must still be in the state - let create_shortstatekey = services() - .rooms - .short - .get_shortstatekey(&StateEventType::RoomCreate, "")? - .expect("Room exists"); - - if state - .get(&create_shortstatekey) - .map(std::convert::AsRef::as_ref) - != Some(&create_event.event_id) - { - return Err(Error::bad_database( - "Incoming event refers to wrong create event.", - )); - } - - state_at_incoming_event = Some(state); - } - Err(e) => { - warn!("Fetching state for event failed: {}", e); - return Err(e); - } - }; - } - - let state_at_incoming_event = - state_at_incoming_event.expect("we always set this to some above"); - - debug!("Starting auth check"); - // 11. Check the auth of the event passes based on the state of the event - let check_result = state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, // TODO: third party invite - |k, s| { - services() - .rooms - .short - .get_shortstatekey(&k.to_string().into(), s) - .ok() - .flatten() - .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) - .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) - }, - ) - .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; - - if !check_result { - return Err(Error::bad_database( - "Event has failed auth check with state at the event.", - )); - } - debug!("Auth check succeeded"); - - // Soft fail check before doing state res - let auth_events = services().rooms.state.get_auth_events( - room_id, - &incoming_pdu.kind, - &incoming_pdu.sender, - incoming_pdu.state_key.as_deref(), - &incoming_pdu.content, - )?; - - let soft_fail = !state_res::event_auth::auth_check( - &room_version, - &incoming_pdu, - None::, - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - 
.map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; - - // 13. Use state resolution to find new room state - - // We start looking at current room state now, so lets lock the room - let mutex_state = Arc::clone( - services() - .globals - .roomid_mutex_state - .write() - .unwrap() - .entry(room_id.to_owned()) - .or_default(), - ); - let state_lock = mutex_state.lock().await; - - // Now we calculate the set of extremities this room has after the incoming event has been - // applied. We start with the previous extremities (aka leaves) - debug!("Calculating extremities"); - let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; - debug!("Amount of forward extremities in room {room_id}: {extremities:?}"); - - // Remove any forward extremities that are referenced by this incoming event's prev_events - for prev_event in &incoming_pdu.prev_events { - if extremities.contains(prev_event) { - extremities.remove(prev_event); - } - } - - // Only keep those extremities were not referenced yet - extremities.retain(|id| { - !matches!( - services() - .rooms - .pdu_metadata - .is_event_referenced(room_id, id), - Ok(true) - ) - }); - - debug!("Compressing state at event"); - let state_ids_compressed = Arc::new( - state_at_incoming_event - .iter() - .map(|(shortstatekey, id)| { - services() - .rooms - .state_compressor - .compress_state_event(*shortstatekey, id) - }) - .collect::>()?, - ); - - if incoming_pdu.state_key.is_some() { - debug!("Preparing for stateres to derive new room state"); - - // We also add state after incoming event to the fork states - let mut state_after = state_at_incoming_event.clone(); - if let Some(state_key) = &incoming_pdu.state_key { - let shortstatekey = services().rooms.short.get_or_create_shortstatekey( - &incoming_pdu.kind.to_string().into(), - state_key, - )?; - - state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); - } - - let new_room_state = self - .resolve_state(room_id, 
room_version_id, state_after) - .await?; - - // Set the new room state to the resolved state - debug!("Forcing new room state"); - - let (sstatehash, new, removed) = services() - .rooms - .state_compressor - .save_state(room_id, new_room_state)?; - - services() - .rooms - .state - .force_state(room_id, sstatehash, new, removed, &state_lock) - .await?; - } - - // 14. Check if the event passes auth based on the "current state" of the room, if not soft fail it - debug!("Starting soft fail auth check"); - - if soft_fail { - services() - .rooms - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.iter().map(|e| (**e).to_owned()).collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - // Soft fail, we keep the event as an outlier but don't add it to the timeline - warn!("Event was soft failed: {:?}", incoming_pdu); - services() - .rooms - .pdu_metadata - .mark_event_soft_failed(&incoming_pdu.event_id)?; - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event has been soft failed", - )); - } - - debug!("Appending pdu to timeline"); - extremities.insert(incoming_pdu.event_id.clone()); - - // Now that the event has passed all auth it is added into the timeline. - // We use the `state_at_event` instead of `state_after` so we accurately - // represent the state for this event. - - let pdu_id = services() - .rooms - .timeline - .append_incoming_pdu( - &incoming_pdu, - val, - extremities.iter().map(|e| (**e).to_owned()).collect(), - state_ids_compressed, - soft_fail, - &state_lock, - ) - .await?; - - debug!("Appended incoming pdu"); - - // Event has passed all auth/stateres checks - drop(state_lock); - Ok(pdu_id) - } - - async fn resolve_state( - &self, - room_id: &RoomId, - room_version_id: &RoomVersionId, - incoming_state: HashMap>, - ) -> Result>> { - debug!("Loading current room state ids"); - let current_sstatehash = services() - .rooms - .state - .get_room_shortstatehash(room_id)? 
- .expect("every room has state"); - - let current_state_ids = services() - .rooms - .state_accessor - .state_full_ids(current_sstatehash) - .await?; - - let fork_states = [current_state_ids, incoming_state]; - - let mut auth_chain_sets = Vec::new(); - for state in &fork_states { - auth_chain_sets.push( - services() - .rooms - .auth_chain - .get_auth_chain(room_id, state.iter().map(|(_, id)| id.clone()).collect()) - .await? - .collect(), - ); - } - - debug!("Loading fork states"); - - let fork_states: Vec<_> = fork_states - .into_iter() - .map(|map| { - map.into_iter() - .filter_map(|(k, id)| { - services() - .rooms - .short - .get_statekey_from_short(k) - .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) - .ok() - }) - .collect::>() - }) - .collect(); - - debug!("Resolving state"); - - let lock = services().globals.stateres_mutex.lock(); - let state_resolve = - state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { - let res = services().rooms.timeline.get_pdu(id); - if let Err(e) = &res { - error!("Failed to fetch event: {}", e); - } - res.ok().flatten() - }); - - let state = match state_resolve { - Ok(new_state) => new_state, - Err(e) => { - error!("State resolution failed: {}", e); - return Err(Error::bad_database("State resolution failed, either an event could not be found or deserialization")); - } - }; - - drop(lock); - - debug!("State resolution done. Compressing state"); - - let new_room_state = state - .into_iter() - .map(|((event_type, state_key), event_id)| { - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; - services() - .rooms - .state_compressor - .compress_state_event(shortstatekey, &event_id) - }) - .collect::>()?; - - Ok(Arc::new(new_room_state)) - } - - /// Find the event and auth it. Once the event is validated (steps 1 - 8) - /// it is appended to the outliers Tree. - /// - /// Returns pdu and if we fetched it over federation the raw json. 
- /// - /// a. Look in the main timeline (pduid_pdu tree) - /// b. Look at outlier pdu tree - /// c. Ask origin server over federation - /// d. TODO: Ask other servers over federation? - #[tracing::instrument(skip_all)] - pub(crate) fn fetch_and_handle_outliers<'a>( - &'a self, - origin: &'a ServerName, - events: &'a [Arc], - create_event: &'a PduEvent, - room_id: &'a RoomId, - room_version_id: &'a RoomVersionId, - pub_key_map: &'a RwLock>>, - ) -> AsyncRecursiveCanonicalJsonVec<'a> { - Box::pin(async move { - let back_off = |id| match services() - .globals - .bad_event_ratelimiter - .write() - .unwrap() - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - let mut events_with_auth_events = vec![]; - for id in events { - // a. Look in the main timeline (pduid_pdu tree) - // b. Look at outlier pdu tree - // (get_pdu_json checks both) - if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { - trace!("Found {} in db", id); - events_with_auth_events.push((id, Some(local_pdu), vec![])); - continue; - } - - // c. Ask origin server over federation - // We also handle its auth chain here so we don't get a stack overflow in - // handle_outlier_pdu. 
- let mut todo_auth_events = vec![Arc::clone(id)]; - let mut events_in_reverse_order = Vec::new(); - let mut events_all = HashSet::new(); - let mut i = 0; - while let Some(next_id) = todo_auth_events.pop() { - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&*next_id) - { - // Exponential backoff - let mut min_elapsed_duration = - Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); - continue; - } - } - - if events_all.contains(&next_id) { - continue; - } - - i += 1; - if i % 100 == 0 { - tokio::task::yield_now().await; - } - - if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { - trace!("Found {} in db", next_id); - continue; - } - - info!("Fetching {} over federation.", next_id); - match services() - .sending - .send_federation_request( - origin, - get_event::v1::Request { - event_id: (*next_id).to_owned(), - }, - ) - .await - { - Ok(res) => { - info!("Got {} over federation", next_id); - let (calculated_event_id, value) = - match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) { - Ok(t) => t, - Err(_) => { - back_off((*next_id).to_owned()); - continue; - } - }; - - if calculated_event_id != *next_id { - warn!("Server didn't return event id we requested: requested: {}, we got {}. 
Event: {:?}", - next_id, calculated_event_id, &res.pdu); - } - - if let Some(auth_events) = - value.get("auth_events").and_then(|c| c.as_array()) - { - for auth_event in auth_events { - if let Ok(auth_event) = - serde_json::from_value(auth_event.clone().into()) - { - let a: Arc = auth_event; - todo_auth_events.push(a); - } else { - warn!("Auth event id is not valid"); - } - } - } else { - warn!("Auth event list invalid"); - } - - events_in_reverse_order.push((next_id.clone(), value)); - events_all.insert(next_id); - } - Err(e) => { - warn!("Failed to fetch event {} | {e}", next_id); - back_off((*next_id).to_owned()); - } - } - } - events_with_auth_events.push((id, None, events_in_reverse_order)); - } - - // We go through all the signatures we see on the PDUs and their unresolved - // dependencies and fetch the corresponding signing keys - info!("fetch_required_signing_keys for {}", origin); - self.fetch_required_signing_keys( - events_with_auth_events - .iter() - .flat_map(|(_id, _local_pdu, events)| events) - .map(|(_event_id, event)| event), - pub_key_map, - ) - .await - .unwrap_or_else(|e| { - warn!( - "Could not fetch all signatures for PDUs from {}: {:?}", - origin, e - ); - }); - - let mut pdus = vec![]; - for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { - // a. Look in the main timeline (pduid_pdu tree) - // b. 
Look at outlier pdu tree - // (get_pdu_json checks both) - if let Some(local_pdu) = local_pdu { - trace!("Found {} in db", id); - pdus.push((local_pdu, None)); - } - for (next_id, value) in events_in_reverse_order.iter().rev() { - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(&**next_id) - { - // Exponential backoff - let mut min_elapsed_duration = - Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - info!("Backing off from {}", next_id); - continue; - } - } - - match self - .handle_outlier_pdu( - origin, - create_event, - next_id, - room_id, - value.clone(), - true, - pub_key_map, - ) - .await - { - Ok((pdu, json)) => { - if next_id == id { - pdus.push((pdu, Some(json))); - } - } - Err(e) => { - warn!("Authentication of event {} failed: {:?}", next_id, e); - back_off((**next_id).to_owned()); - } - } - } - } - pdus - }) - } - - async fn fetch_unknown_prev_events( - &self, - origin: &ServerName, - create_event: &PduEvent, - room_id: &RoomId, - room_version_id: &RoomVersionId, - pub_key_map: &RwLock>>, - initial_set: Vec>, - ) -> Result<( - Vec>, - HashMap, (Arc, BTreeMap)>, - )> { - let mut graph: HashMap, _> = HashMap::new(); - let mut eventid_info = HashMap::new(); - let mut todo_outlier_stack: Vec> = initial_set; - - let first_pdu_in_room = services() - .rooms - .timeline - .first_pdu_in_room(room_id)? 
- .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; - - let mut amount = 0; - - while let Some(prev_event_id) = todo_outlier_stack.pop() { - if let Some((pdu, json_opt)) = self - .fetch_and_handle_outliers( - origin, - &[prev_event_id.clone()], - create_event, - room_id, - room_version_id, - pub_key_map, - ) - .await - .pop() - { - self.check_room_id(room_id, &pdu)?; - - if amount > services().globals.max_fetch_prev_events() { - // Max limit reached - info!( - "Max prev event limit reached! Limit: {}", - services().globals.max_fetch_prev_events() - ); - graph.insert(prev_event_id.clone(), HashSet::new()); - continue; - } - - if let Some(json) = json_opt.or_else(|| { - services() - .rooms - .outlier - .get_outlier_pdu_json(&prev_event_id) - .ok() - .flatten() - }) { - if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { - amount += 1; - for prev_prev in &pdu.prev_events { - if !graph.contains_key(prev_prev) { - todo_outlier_stack.push(prev_prev.clone()); - } - } - - graph.insert( - prev_event_id.clone(), - pdu.prev_events.iter().cloned().collect(), - ); - } else { - // Time based check failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - - eventid_info.insert(prev_event_id.clone(), (pdu, json)); - } else { - // Get json failed, so this was not fetched over federation - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } else { - // Fetch and handle failed - graph.insert(prev_event_id.clone(), HashSet::new()); - } - } - - let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { - // This return value is the key used for sorting events, - // events are then sorted by power level, time, - // and lexically by event_id. 
- Ok(( - int!(0), - MilliSecondsSinceUnixEpoch( - eventid_info - .get(event_id) - .map_or_else(|| uint!(0), |info| info.0.origin_server_ts), - ), - )) - }) - .map_err(|e| { - error!("Error sorting prev events: {e}"); - Error::bad_database("Error sorting prev events") - })?; - - Ok((sorted, eventid_info)) - } - - #[tracing::instrument(skip_all)] - pub(crate) async fn fetch_required_signing_keys<'a, E>( - &'a self, - events: E, - pub_key_map: &RwLock>>, - ) -> Result<()> - where - E: IntoIterator>, - { - let mut server_key_ids = HashMap::new(); - - for event in events.into_iter() { - debug!("Fetching keys for event: {event:?}"); - for (signature_server, signature) in event - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))? - { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - for signature_id in signature_object.keys() { - server_key_ids - .entry(signature_server.clone()) - .or_insert_with(HashSet::new) - .insert(signature_id.clone()); - } - } - } - - if server_key_ids.is_empty() { - // Nothing to do, can exit early - debug!("server_key_ids is empty, not fetching any keys"); - return Ok(()); - } - - info!( - "Fetch keys for {}", - server_key_ids - .keys() - .cloned() - .collect::>() - .join(", ") - ); - - let mut server_keys: FuturesUnordered<_> = server_key_ids - .into_iter() - .map(|(signature_server, signature_ids)| async { - let fetch_res = self - .fetch_signing_keys_for_server( - signature_server.as_str().try_into().map_err(|e| { - info!("Invalid servername in signatures of server response pdu: {e}"); - ( - signature_server.clone(), - Error::BadServerResponse( - "Invalid servername in signatures of server response pdu.", - ), - ) - })?, - signature_ids.into_iter().collect(), // HashSet to Vec - ) - 
.await; - - match fetch_res { - Ok(keys) => Ok((signature_server, keys)), - Err(e) => { - warn!("Signature verification failed: Could not fetch signing key for {signature_server}: {e}",); - Err((signature_server, e)) - } - } - }) - .collect(); - - while let Some(fetch_res) = server_keys.next().await { - match fetch_res { - Ok((signature_server, keys)) => { - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? - .insert(signature_server.clone(), keys); - } - Err((signature_server, e)) => { - warn!("Failed to fetch keys for {}: {:?}", signature_server, e); - } - } - } - - Ok(()) - } - - // Gets a list of servers for which we don't have the signing key yet. We go over - // the PDUs and either cache the key or add it to the list that needs to be retrieved. - fn get_server_keys_from_cache( - &self, - pdu: &RawJsonValue, - servers: &mut BTreeMap>, - room_version: &RoomVersionId, - pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, - ) -> Result<()> { - let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { - error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); - Error::BadServerResponse("Invalid PDU in server response") - })?; - - let event_id = format!( - "${}", - ruma::signatures::reference_hash(&value, room_version) - .expect("ruma can calculate reference hashes") - ); - let event_id = <&EventId>::try_from(event_id.as_str()) - .expect("ruma's reference hashes are valid event ids"); - - if let Some((time, tries)) = services() - .globals - .bad_event_ratelimiter - .read() - .unwrap() - .get(event_id) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {}", event_id); - return Err(Error::BadServerResponse("bad event, still backing off")); - } - } - - 
let signatures = value - .get("signatures") - .ok_or(Error::BadServerResponse( - "No signatures in server response pdu.", - ))? - .as_object() - .ok_or(Error::BadServerResponse( - "Invalid signatures object in server response pdu.", - ))?; - - for (signature_server, signature) in signatures { - let signature_object = signature.as_object().ok_or(Error::BadServerResponse( - "Invalid signatures content object in server response pdu.", - ))?; - - let signature_ids = signature_object.keys().cloned().collect::>(); - - let contains_all_ids = |keys: &BTreeMap| { - signature_ids.iter().all(|id| keys.contains_key(id)) - }; - - let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { - info!("Invalid servername in signatures of server response pdu: {e}"); - Error::BadServerResponse("Invalid servername in signatures of server response pdu.") - })?; - - if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { - continue; - } - - debug!("Loading signing keys for {}", origin); - - let result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? 
- .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if !contains_all_ids(&result) { - debug!("Signing key not loaded for {}", origin); - servers.insert(origin.to_owned(), BTreeMap::new()); - } - - pub_key_map.insert(origin.to_string(), result); - } - - Ok(()) - } - - async fn batch_request_signing_keys( - &self, - mut servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - for server in services().globals.trusted_servers() { - info!("Asking batch signing keys from trusted server {}", server); - if let Ok(keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys_batch::v2::Request { - server_keys: servers.clone(), - }, - ) - .await - { - debug!("Got signing keys: {:?}", keys); - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - for k in keys.server_keys { - let k = match k.deserialize() { - Ok(key) => key, - Err(e) => { - warn!( - "Received error {} while fetching keys from trusted server {}", - e, server - ); - warn!("{}", k.into_json()); - continue; - } - }; - - // TODO: Check signature from trusted server? - servers.remove(&k.server_name); - - let result = services() - .globals - .add_signing_key(&k.server_name, k.clone())? 
- .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect::>(); - - pkm.insert(k.server_name.to_string(), result); - } - } - } - - Ok(()) - } - - async fn request_signing_keys( - &self, - servers: BTreeMap>, - pub_key_map: &RwLock>>, - ) -> Result<()> { - info!("Asking individual servers for signing keys: {servers:?}"); - let mut futures: FuturesUnordered<_> = servers - .into_keys() - .map(|server| async move { - ( - services() - .sending - .send_federation_request(&server, get_server_keys::v2::Request::new()) - .await, - server, - ) - }) - .collect(); - - while let Some(result) = futures.next().await { - debug!("Received new Future result"); - if let (Ok(get_keys_response), origin) = result { - info!("Result is from {origin}"); - if let Ok(key) = get_keys_response.server_key.deserialize() { - let result: BTreeMap<_, _> = services() - .globals - .add_signing_key(&origin, key)? - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))? 
- .insert(origin.to_string(), result); - } - } - debug!("Done handling Future result"); - } - - Ok(()) - } - - pub(crate) async fn fetch_join_signing_keys( - &self, - event: &create_join_event::v2::Response, - room_version: &RoomVersionId, - pub_key_map: &RwLock>>, - ) -> Result<()> { - let mut servers: BTreeMap< - OwnedServerName, - BTreeMap, - > = BTreeMap::new(); - - { - let mut pkm = pub_key_map - .write() - .map_err(|_| Error::bad_database("RwLock is poisoned."))?; - - // Try to fetch keys, failure is okay - // Servers we couldn't find in the cache will be added to `servers` - for pdu in &event.room_state.state { - let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); - } - for pdu in &event.room_state.auth_chain { - let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); - } - - drop(pkm); - }; - - if servers.is_empty() { - info!("We had all keys cached locally, not fetching any keys from remote servers"); - return Ok(()); - } - - if services().globals.query_trusted_key_servers_first() { - info!("query_trusted_key_servers_first is set to true, querying notary trusted key servers first for homeserver signing keys."); - - self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - info!("Trusted server supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - } else { - info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); - - self.request_signing_keys(servers.clone(), pub_key_map) - .await?; - - if servers.is_empty() { - info!("Individual homeservers supplied all signing keys, no more keys to fetch"); - return Ok(()); - } - - info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); - - 
self.batch_request_signing_keys(servers.clone(), pub_key_map) - .await?; - } - - info!("Search for signing keys done"); - - /*if servers.is_empty() { - warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); - }*/ - - Ok(()) - } - - /// Returns Ok if the acl allows the server - pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { - let acl_event = match services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomServerAcl, - "", - )? { - Some(acl) => { - debug!("ACL event found: {acl:?}"); - acl - } - None => { - info!("No ACL event found"); - return Ok(()); - } - }; - - let acl_event_content: RoomServerAclEventContent = - match serde_json::from_str(acl_event.content.get()) { - Ok(content) => { - debug!("Found ACL event contents: {content:?}"); - content - } - Err(e) => { - warn!("Invalid ACL event: {e}"); - return Ok(()); - } - }; - - if acl_event_content.allow.is_empty() { - warn!("Ignoring broken ACL event (allow key is empty)"); - // Ignore broken acl events - return Ok(()); - } - - if acl_event_content.is_allowed(server_name) { - debug!("server {server_name} is allowed by ACL"); - Ok(()) - } else { - info!( - "Server {} was denied by room ACL in {}", - server_name, room_id - ); - Err(Error::BadRequest( - ErrorKind::Forbidden, - "Server was denied by room ACL", - )) - } - } - - /// Search the DB for the signing keys of the given server, if we don't have them - /// fetch them from the server and save to our DB. 
- #[tracing::instrument(skip_all)] - pub async fn fetch_signing_keys_for_server( - &self, - origin: &ServerName, - signature_ids: Vec, - ) -> Result> { - let contains_all_ids = - |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); - - let permit = services() - .globals - .servername_ratelimiter - .read() - .unwrap() - .get(origin) - .map(|s| Arc::clone(s).acquire_owned()); - - let permit = match permit { - Some(p) => p, - None => { - let mut write = services().globals.servername_ratelimiter.write().unwrap(); - let s = Arc::clone( - write - .entry(origin.to_owned()) - .or_insert_with(|| Arc::new(Semaphore::new(1))), - ); - - s.acquire_owned() - } - } - .await; - - let back_off = |id| match services() - .globals - .bad_signature_ratelimiter - .write() - .unwrap() - .entry(id) - { - hash_map::Entry::Vacant(e) => { - e.insert((Instant::now(), 1)); - } - hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), - }; - - if let Some((time, tries)) = services() - .globals - .bad_signature_ratelimiter - .read() - .unwrap() - .get(&signature_ids) - { - // Exponential backoff - let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - debug!("Backing off from {:?}", signature_ids); - return Err(Error::BadServerResponse("bad signature, still backing off")); - } - } - - debug!("Loading signing keys for {}", origin); - - let mut result: BTreeMap<_, _> = services() - .globals - .signing_keys_for(origin)? 
- .into_iter() - .map(|(k, v)| (k.to_string(), v.key)) - .collect(); - - if contains_all_ids(&result) { - return Ok(result); - } - - debug!("Fetching signing keys for {} over federation", origin); - - if let Some(server_key) = services() - .sending - .send_federation_request(origin, get_server_keys::v2::Request::new()) - .await - .ok() - .and_then(|resp| resp.server_key.deserialize().ok()) - { - services() - .globals - .add_signing_key(origin, server_key.clone())?; - - result.extend( - server_key - .verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - server_key - .old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - - if contains_all_ids(&result) { - return Ok(result); - } - } - - for server in services().globals.trusted_servers() { - debug!("Asking {} for {}'s signing key", server, origin); - if let Some(server_keys) = services() - .sending - .send_federation_request( - server, - get_remote_server_keys::v2::Request::new( - origin.to_owned(), - MilliSecondsSinceUnixEpoch::from_system_time( - SystemTime::now() - .checked_add(Duration::from_secs(3600)) - .expect("SystemTime too large"), - ) - .expect("time is valid"), - ), - ) - .await - .ok() - .map(|resp| { - resp.server_keys - .into_iter() - .filter_map(|e| e.deserialize().ok()) - .collect::>() - }) - { - debug!("Got signing keys: {:?}", server_keys); - for k in server_keys { - services().globals.add_signing_key(origin, k.clone())?; - result.extend( - k.verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - result.extend( - k.old_verify_keys - .into_iter() - .map(|(k, v)| (k.to_string(), v.key)), - ); - } - - if contains_all_ids(&result) { - return Ok(result); - } - } - } - - drop(permit); - - back_off(signature_ids); - - warn!("Failed to find public key for server: {}", origin); - Err(Error::BadServerResponse( - "Failed to find public key for server", - )) - } - - fn check_room_id(&self, room_id: &RoomId, pdu: &PduEvent) -> 
Result<()> { - if pdu.room_id != room_id { - warn!("Found event from room {} in room {}", pdu.room_id, room_id); - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Event has wrong room id", - )); - } - Ok(()) - } + /// When receiving an event one needs to: + /// 0. Check the server is in the room + /// 1. Skip the PDU if we already know about it + /// 1.1. Remove unsigned field + /// 2. Check signatures, otherwise drop + /// 3. Check content hash, redact if doesn't match + /// 4. Fetch any missing auth events doing all checks listed here starting + /// at 1. These are not timeline events + /// 5. Reject "due to auth events" if can't get all the auth events or some + /// of the auth events are also rejected "due to auth events" + /// 6. Reject "due to auth events" if the event doesn't pass auth based on + /// the auth events + /// 7. Persist this event as an outlier + /// 8. If not timeline event: stop + /// 9. Fetch any missing prev events doing all checks listed here starting + /// at 1. These are timeline events + /// 10. Fetch missing state and auth chain events by calling `/state_ids` at + /// backwards extremities doing all the checks in this list starting at + /// 1. These are not timeline events + /// 11. Check the auth of the event passes based on the state of the event + /// 12. Ensure that the state is derived from the previous current state + /// (i.e. we calculated by doing state res where one of the inputs was a + /// previously trusted set of state, don't just trust a set of state we + /// got from a remote) + /// 13. Use state resolution to find new room state + /// 14. 
Check if the event passes auth based on the "current state" of the + /// room, if not soft fail it + // We use some AsyncRecursiveType hacks here so we can call this async funtion + // recursively + pub(crate) async fn handle_incoming_pdu<'a>( + &self, origin: &'a ServerName, event_id: &'a EventId, room_id: &'a RoomId, + value: BTreeMap, is_timeline_event: bool, + pub_key_map: &'a RwLock>>, + ) -> Result>> { + // 0. Check the server is in the room + if !services().rooms.metadata.exists(room_id)? { + return Err(Error::BadRequest(ErrorKind::NotFound, "Room is unknown to this server")); + } + + if services().rooms.metadata.is_disabled(room_id)? { + info!( + "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ + event ID {event_id}" + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of this room is currently disabled on this server.", + )); + } + + services().rooms.event_handler.acl_check(origin, room_id)?; + + // 1. Skip the PDU if we already have it as a timeline event + if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(event_id)? { + return Ok(Some(pdu_id)); + } + + let create_event = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .ok_or_else(|| Error::bad_database("Failed to find create event in db."))?; + + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + let room_version_id = &create_event_content.room_version; + + let first_pdu_in_room = services() + .rooms + .timeline + .first_pdu_in_room(room_id)? + .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + + let (incoming_pdu, val) = + self.handle_outlier_pdu(origin, &create_event, event_id, room_id, value, false, pub_key_map).await?; + self.check_room_id(room_id, &incoming_pdu)?; + + // 8. 
if not timeline event: stop + if !is_timeline_event { + return Ok(None); + } + + // Skip old events + if incoming_pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + return Ok(None); + } + + // 9. Fetch any missing prev events doing all checks listed here starting at 1. + // These are timeline events + let (sorted_prev_events, mut eventid_info) = self + .fetch_unknown_prev_events( + origin, + &create_event, + room_id, + room_version_id, + pub_key_map, + incoming_pdu.prev_events.clone(), + ) + .await?; + + let mut errors = 0; + debug!(events = ?sorted_prev_events, "Got previous events"); + for prev_id in sorted_prev_events { + // Check for disabled again because it might have changed + if services().rooms.metadata.is_disabled(room_id)? { + info!( + "Federaton of room {room_id} is currently disabled on this server. Request by origin {origin} and \ + event ID {event_id}" + ); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Federation of this room is currently disabled on this server.", + )); + } + + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*prev_id) { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", prev_id); + continue; + } + } + + if errors >= 5 { + // Timeout other events + match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + }, + } + continue; + } + + if let Some((pdu, json)) = eventid_info.remove(&*prev_id) { + // Skip old events + if pdu.origin_server_ts < first_pdu_in_room.origin_server_ts { + continue; + } + + let start_time = 
Instant::now(); + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), ((*prev_id).to_owned(), start_time)); + + if let Err(e) = + self.upgrade_outlier_to_timeline_pdu(pdu, json, &create_event, origin, room_id, pub_key_map).await + { + errors += 1; + warn!("Prev event {} failed: {}", prev_id, e); + match services().globals.bad_event_ratelimiter.write().unwrap().entry((*prev_id).to_owned()) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => { + *e.get_mut() = (Instant::now(), e.get().1 + 1); + }, + } + } + let elapsed = start_time.elapsed(); + services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned()); + debug!( + "Handling prev event {} took {}m{}s", + prev_id, + elapsed.as_secs() / 60, + elapsed.as_secs() % 60 + ); + } + } + + // Done with prev events, now handling the incoming event + + let start_time = Instant::now(); + services() + .globals + .roomid_federationhandletime + .write() + .unwrap() + .insert(room_id.to_owned(), (event_id.to_owned(), start_time)); + let r = services() + .rooms + .event_handler + .upgrade_outlier_to_timeline_pdu(incoming_pdu, val, &create_event, origin, room_id, pub_key_map) + .await; + services().globals.roomid_federationhandletime.write().unwrap().remove(&room_id.to_owned()); + + r + } + + #[allow(clippy::too_many_arguments)] + fn handle_outlier_pdu<'a>( + &'a self, origin: &'a ServerName, create_event: &'a PduEvent, event_id: &'a EventId, room_id: &'a RoomId, + mut value: BTreeMap, auth_events_known: bool, + pub_key_map: &'a RwLock>>, + ) -> AsyncRecursiveCanonicalJsonResult<'a> { + Box::pin(async move { + // 1. Remove unsigned field + value.remove("unsigned"); + + // TODO: For RoomVersion6 we must check that Raw<..> is canonical do we anywhere?: https://matrix.org/docs/spec/rooms/v6#canonical-json + + // 2. Check signatures, otherwise drop + // 3. 
check content hash, redact if doesn't match + let create_event_content: RoomCreateEventContent = serde_json::from_str(create_event.content.get()) + .map_err(|e| { + error!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + let mut val = match ruma::signatures::verify_event( + &pub_key_map.read().expect("RwLock is poisoned."), + &value, + room_version_id, + ) { + Err(e) => { + // Drop + warn!("Dropping bad event {}: {}", event_id, e,); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Signature verification failed")); + }, + Ok(ruma::signatures::Verified::Signatures) => { + // Redact + warn!("Calculated hash does not match: {}", event_id); + let obj = match ruma::canonical_json::redact(value, room_version_id, None) { + Ok(obj) => obj, + Err(_) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Redaction failed")), + }; + + // Skip the PDU if it is redacted and we already have it as an outlier event + if services().rooms.timeline.get_pdu_json(event_id)?.is_some() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Event was redacted and we already knew about it", + )); + } + + obj + }, + Ok(ruma::signatures::Verified::All) => value, + }; + + // Now that we have checked the signature and hashes we can add the eventID and + // convert to our PduEvent type + val.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned())); + let incoming_pdu = serde_json::from_value::( + serde_json::to_value(&val).expect("CanonicalJsonObj is a valid JsonValue"), + ) + .map_err(|_| Error::bad_database("Event is not a valid PDU."))?; + + self.check_room_id(room_id, &incoming_pdu)?; + + if !auth_events_known { + // 4. fetch any missing auth events doing all checks listed here starting at 1. + // These are not timeline events + // 5. 
Reject "due to auth events" if can't get all the auth events or some of + // the auth events are also rejected "due to auth events" + // NOTE: Step 5 is not applied anymore because it failed too often + debug!(event_id = ?incoming_pdu.event_id, "Fetching auth events"); + self.fetch_and_handle_outliers( + origin, + &incoming_pdu.auth_events.iter().map(|x| Arc::from(&**x)).collect::>(), + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await; + } + + // 6. Reject "due to auth events" if the event doesn't pass auth based on the + // auth events + debug!("Auth check for {} based on auth events", incoming_pdu.event_id); + + // Build map of auth events + let mut auth_events = HashMap::new(); + for id in &incoming_pdu.auth_events { + let auth_event = match services().rooms.timeline.get_pdu(id)? { + Some(e) => e, + None => { + warn!("Could not find auth event {}", id); + continue; + }, + }; + + self.check_room_id(room_id, &auth_event)?; + + match auth_events.entry(( + auth_event.kind.to_string().into(), + auth_event.state_key.clone().expect("all auth events have state keys"), + )) { + hash_map::Entry::Vacant(v) => { + v.insert(auth_event); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Auth event's type and state_key combination exists multiple times.", + )); + }, + } + } + + // The original create event must be in the auth events + if !matches!( + auth_events.get(&(StateEventType::RoomCreate, "".to_owned())).map(std::convert::AsRef::as_ref), + Some(_) | None + ) { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Incoming event refers to wrong create event.", + )); + } + + if !state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::, // TODO: third party invite + |k, s| auth_events.get(&(k.to_string().into(), s.to_owned())), + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed"))? 
+ { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed")); + } + + debug!("Validation successful."); + + // 7. Persist the event as an outlier. + services().rooms.outlier.add_pdu_outlier(&incoming_pdu.event_id, &val)?; + + debug!("Added pdu as outlier."); + + Ok((Arc::new(incoming_pdu), val)) + }) + } + + pub async fn upgrade_outlier_to_timeline_pdu( + &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, + origin: &ServerName, room_id: &RoomId, pub_key_map: &RwLock>>, + ) -> Result>> { + // Skip the PDU if we already have it as a timeline event + if let Ok(Some(pduid)) = services().rooms.timeline.get_pdu_id(&incoming_pdu.event_id) { + return Ok(Some(pduid)); + } + + if services().rooms.pdu_metadata.is_event_soft_failed(&incoming_pdu.event_id)? { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + } + + info!("Upgrading {} to timeline pdu", incoming_pdu.event_id); + + let create_event_content: RoomCreateEventContent = + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::BadDatabase("Invalid create event in db") + })?; + + let room_version_id = &create_event_content.room_version; + let room_version = RoomVersion::new(room_version_id).expect("room version is supported"); + + // 10. Fetch missing state and auth chain events by calling /state_ids at + // backwards extremities doing all the checks in this list starting at 1. + // These are not timeline events. 
+ + // TODO: if we know the prev_events of the incoming event we can avoid the + // request and build the state from a known point and resolve if > 1 prev_event + + debug!("Requesting state at event"); + let mut state_at_incoming_event = None; + + if incoming_pdu.prev_events.len() == 1 { + let prev_event = &*incoming_pdu.prev_events[0]; + let prev_event_sstatehash = services().rooms.state_accessor.pdu_shortstatehash(prev_event)?; + + let state = if let Some(shortstatehash) = prev_event_sstatehash { + Some(services().rooms.state_accessor.state_full_ids(shortstatehash).await) + } else { + None + }; + + if let Some(Ok(mut state)) = state { + debug!("Using cached state"); + let prev_pdu = services() + .rooms + .timeline + .get_pdu(prev_event) + .ok() + .flatten() + .ok_or_else(|| Error::bad_database("Could not find prev event, but we know the state."))?; + + if let Some(state_key) = &prev_pdu.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&prev_pdu.kind.to_string().into(), state_key)?; + + state.insert(shortstatekey, Arc::from(prev_event)); + // Now it's the state after the pdu + } + + state_at_incoming_event = Some(state); + } + } else { + debug!("Calculating state at event using state res"); + let mut extremity_sstatehashes = HashMap::new(); + + let mut okay = true; + for prev_eventid in &incoming_pdu.prev_events { + let prev_event = if let Ok(Some(pdu)) = services().rooms.timeline.get_pdu(prev_eventid) { + pdu + } else { + okay = false; + break; + }; + + let sstatehash = if let Ok(Some(s)) = services().rooms.state_accessor.pdu_shortstatehash(prev_eventid) { + s + } else { + okay = false; + break; + }; + + extremity_sstatehashes.insert(sstatehash, prev_event); + } + + if okay { + let mut fork_states = Vec::with_capacity(extremity_sstatehashes.len()); + let mut auth_chain_sets = Vec::with_capacity(extremity_sstatehashes.len()); + + for (sstatehash, prev_event) in extremity_sstatehashes { + let mut leaf_state: HashMap<_, 
_> = + services().rooms.state_accessor.state_full_ids(sstatehash).await?; + + if let Some(state_key) = &prev_event.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&prev_event.kind.to_string().into(), state_key)?; + leaf_state.insert(shortstatekey, Arc::from(&*prev_event.event_id)); + // Now it's the state after the pdu + } + + let mut state = StateMap::with_capacity(leaf_state.len()); + let mut starting_events = Vec::with_capacity(leaf_state.len()); + + for (k, id) in leaf_state { + if let Ok((ty, st_key)) = services().rooms.short.get_statekey_from_short(k) { + // FIXME: Undo .to_string().into() when StateMap + // is updated to use StateEventType + state.insert((ty.to_string().into(), st_key), id.clone()); + } else { + warn!("Failed to get_statekey_from_short."); + } + starting_events.push(id); + } + + auth_chain_sets + .push(services().rooms.auth_chain.get_auth_chain(room_id, starting_events).await?.collect()); + + fork_states.push(state); + } + + let lock = services().globals.stateres_mutex.lock(); + + let result = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = services().rooms.timeline.get_pdu(id); + if let Err(e) = &res { + error!("Failed to fetch event: {}", e); + } + res.ok().flatten() + }); + drop(lock); + + state_at_incoming_event = match result { + Ok(new_state) => Some( + new_state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + Ok((shortstatekey, event_id)) + }) + .collect::>()?, + ), + Err(e) => { + warn!( + "State resolution on prev events failed, either an event could not be found or \ + deserialization: {}", + e + ); + None + }, + } + } + } + + if state_at_incoming_event.is_none() { + debug!("Calling /state_ids"); + // Call /state_ids to find out what the state at this pdu is. 
We trust the + // server's response to some extend, but we still do a lot of checks on the + // events + match services() + .sending + .send_federation_request( + origin, + get_room_state_ids::v1::Request { + room_id: room_id.to_owned(), + event_id: (*incoming_pdu.event_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + debug!("Fetching state events at event."); + let state_vec = self + .fetch_and_handle_outliers( + origin, + &res.pdu_ids.iter().map(|x| Arc::from(&**x)).collect::>(), + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await; + + let mut state: HashMap<_, Arc> = HashMap::new(); + for (pdu, _) in state_vec { + let state_key = pdu + .state_key + .clone() + .ok_or_else(|| Error::bad_database("Found non-state pdu in state events."))?; + + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&pdu.kind.to_string().into(), &state_key)?; + + match state.entry(shortstatekey) { + hash_map::Entry::Vacant(v) => { + v.insert(Arc::from(&*pdu.event_id)); + }, + hash_map::Entry::Occupied(_) => { + return Err(Error::bad_database( + "State event's type and state_key combination exists multiple times.", + )) + }, + } + } + + // The original create event must still be in the state + let create_shortstatekey = services() + .rooms + .short + .get_shortstatekey(&StateEventType::RoomCreate, "")? + .expect("Room exists"); + + if state.get(&create_shortstatekey).map(std::convert::AsRef::as_ref) != Some(&create_event.event_id) + { + return Err(Error::bad_database("Incoming event refers to wrong create event.")); + } + + state_at_incoming_event = Some(state); + }, + Err(e) => { + warn!("Fetching state for event failed: {}", e); + return Err(e); + }, + }; + } + + let state_at_incoming_event = state_at_incoming_event.expect("we always set this to some above"); + + debug!("Starting auth check"); + // 11. 
Check the auth of the event passes based on the state of the event + let check_result = state_res::event_auth::auth_check( + &room_version, + &incoming_pdu, + None::, // TODO: third party invite + |k, s| { + services() + .rooms + .short + .get_shortstatekey(&k.to_string().into(), s) + .ok() + .flatten() + .and_then(|shortstatekey| state_at_incoming_event.get(&shortstatekey)) + .and_then(|event_id| services().rooms.timeline.get_pdu(event_id).ok().flatten()) + }, + ) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; + + if !check_result { + return Err(Error::bad_database("Event has failed auth check with state at the event.")); + } + debug!("Auth check succeeded"); + + // Soft fail check before doing state res + let auth_events = services().rooms.state.get_auth_events( + room_id, + &incoming_pdu.kind, + &incoming_pdu.sender, + incoming_pdu.state_key.as_deref(), + &incoming_pdu.content, + )?; + + let soft_fail = !state_res::event_auth::auth_check(&room_version, &incoming_pdu, None::, |k, s| { + auth_events.get(&(k.clone(), s.to_owned())) + }) + .map_err(|_e| Error::BadRequest(ErrorKind::InvalidParam, "Auth check failed."))?; + + // 13. Use state resolution to find new room state + + // We start looking at current room state now, so lets lock the room + let mutex_state = + Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default()); + let state_lock = mutex_state.lock().await; + + // Now we calculate the set of extremities this room has after the incoming + // event has been applied. 
We start with the previous extremities (aka leaves) + debug!("Calculating extremities"); + let mut extremities = services().rooms.state.get_forward_extremities(room_id)?; + debug!("Amount of forward extremities in room {room_id}: {extremities:?}"); + + // Remove any forward extremities that are referenced by this incoming event's + // prev_events + for prev_event in &incoming_pdu.prev_events { + if extremities.contains(prev_event) { + extremities.remove(prev_event); + } + } + + // Only keep those extremities were not referenced yet + extremities.retain(|id| !matches!(services().rooms.pdu_metadata.is_event_referenced(room_id, id), Ok(true))); + + debug!("Compressing state at event"); + let state_ids_compressed = Arc::new( + state_at_incoming_event + .iter() + .map(|(shortstatekey, id)| services().rooms.state_compressor.compress_state_event(*shortstatekey, id)) + .collect::>()?, + ); + + if incoming_pdu.state_key.is_some() { + debug!("Preparing for stateres to derive new room state"); + + // We also add state after incoming event to the fork states + let mut state_after = state_at_incoming_event.clone(); + if let Some(state_key) = &incoming_pdu.state_key { + let shortstatekey = services() + .rooms + .short + .get_or_create_shortstatekey(&incoming_pdu.kind.to_string().into(), state_key)?; + + state_after.insert(shortstatekey, Arc::from(&*incoming_pdu.event_id)); + } + + let new_room_state = self.resolve_state(room_id, room_version_id, state_after).await?; + + // Set the new room state to the resolved state + debug!("Forcing new room state"); + + let (sstatehash, new, removed) = services().rooms.state_compressor.save_state(room_id, new_room_state)?; + + services().rooms.state.force_state(room_id, sstatehash, new, removed, &state_lock).await?; + } + + // 14. 
Check if the event passes auth based on the "current state" of the room, + // if not soft fail it + debug!("Starting soft fail auth check"); + + if soft_fail { + services() + .rooms + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + // Soft fail, we keep the event as an outlier but don't add it to the timeline + warn!("Event was soft failed: {:?}", incoming_pdu); + services().rooms.pdu_metadata.mark_event_soft_failed(&incoming_pdu.event_id)?; + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has been soft failed")); + } + + debug!("Appending pdu to timeline"); + extremities.insert(incoming_pdu.event_id.clone()); + + // Now that the event has passed all auth it is added into the timeline. + // We use the `state_at_event` instead of `state_after` so we accurately + // represent the state for this event. + + let pdu_id = services() + .rooms + .timeline + .append_incoming_pdu( + &incoming_pdu, + val, + extremities.iter().map(|e| (**e).to_owned()).collect(), + state_ids_compressed, + soft_fail, + &state_lock, + ) + .await?; + + debug!("Appended incoming pdu"); + + // Event has passed all auth/stateres checks + drop(state_lock); + Ok(pdu_id) + } + + async fn resolve_state( + &self, room_id: &RoomId, room_version_id: &RoomVersionId, incoming_state: HashMap>, + ) -> Result>> { + debug!("Loading current room state ids"); + let current_sstatehash = + services().rooms.state.get_room_shortstatehash(room_id)?.expect("every room has state"); + + let current_state_ids = services().rooms.state_accessor.state_full_ids(current_sstatehash).await?; + + let fork_states = [current_state_ids, incoming_state]; + + let mut auth_chain_sets = Vec::new(); + for state in &fork_states { + auth_chain_sets.push( + services() + .rooms + .auth_chain + .get_auth_chain(room_id, state.iter().map(|(_, id)| id.clone()).collect()) + .await? 
+ .collect(), + ); + } + + debug!("Loading fork states"); + + let fork_states: Vec<_> = fork_states + .into_iter() + .map(|map| { + map.into_iter() + .filter_map(|(k, id)| { + services() + .rooms + .short + .get_statekey_from_short(k) + .map(|(ty, st_key)| ((ty.to_string().into(), st_key), id)) + .ok() + }) + .collect::>() + }) + .collect(); + + debug!("Resolving state"); + + let lock = services().globals.stateres_mutex.lock(); + let state_resolve = state_res::resolve(room_version_id, &fork_states, auth_chain_sets, |id| { + let res = services().rooms.timeline.get_pdu(id); + if let Err(e) = &res { + error!("Failed to fetch event: {}", e); + } + res.ok().flatten() + }); + + let state = match state_resolve { + Ok(new_state) => new_state, + Err(e) => { + error!("State resolution failed: {}", e); + return Err(Error::bad_database( + "State resolution failed, either an event could not be found or deserialization", + )); + }, + }; + + drop(lock); + + debug!("State resolution done. Compressing state"); + + let new_room_state = state + .into_iter() + .map(|((event_type, state_key), event_id)| { + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey(&event_type.to_string().into(), &state_key)?; + services().rooms.state_compressor.compress_state_event(shortstatekey, &event_id) + }) + .collect::>()?; + + Ok(Arc::new(new_room_state)) + } + + /// Find the event and auth it. Once the event is validated (steps 1 - 8) + /// it is appended to the outliers Tree. + /// + /// Returns pdu and if we fetched it over federation the raw json. + /// + /// a. Look in the main timeline (pduid_pdu tree) + /// b. Look at outlier pdu tree + /// c. Ask origin server over federation + /// d. TODO: Ask other servers over federation? 
+ #[tracing::instrument(skip_all)] + pub(crate) fn fetch_and_handle_outliers<'a>( + &'a self, origin: &'a ServerName, events: &'a [Arc], create_event: &'a PduEvent, room_id: &'a RoomId, + room_version_id: &'a RoomVersionId, pub_key_map: &'a RwLock>>, + ) -> AsyncRecursiveCanonicalJsonVec<'a> { + Box::pin(async move { + let back_off = |id| match services().globals.bad_event_ratelimiter.write().unwrap().entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + let mut events_with_auth_events = vec![]; + for id in events { + // a. Look in the main timeline (pduid_pdu tree) + // b. Look at outlier pdu tree + // (get_pdu_json checks both) + if let Ok(Some(local_pdu)) = services().rooms.timeline.get_pdu(id) { + trace!("Found {} in db", id); + events_with_auth_events.push((id, Some(local_pdu), vec![])); + continue; + } + + // c. Ask origin server over federation + // We also handle its auth chain here so we don't get a stack overflow in + // handle_outlier_pdu. 
+ let mut todo_auth_events = vec![Arc::clone(id)]; + let mut events_in_reverse_order = Vec::new(); + let mut events_all = HashSet::new(); + let mut i = 0; + while let Some(next_id) = todo_auth_events.pop() { + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(&*next_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", next_id); + continue; + } + } + + if events_all.contains(&next_id) { + continue; + } + + i += 1; + if i % 100 == 0 { + tokio::task::yield_now().await; + } + + if let Ok(Some(_)) = services().rooms.timeline.get_pdu(&next_id) { + trace!("Found {} in db", next_id); + continue; + } + + info!("Fetching {} over federation.", next_id); + match services() + .sending + .send_federation_request( + origin, + get_event::v1::Request { + event_id: (*next_id).to_owned(), + }, + ) + .await + { + Ok(res) => { + info!("Got {} over federation", next_id); + let (calculated_event_id, value) = + match pdu::gen_event_id_canonical_json(&res.pdu, room_version_id) { + Ok(t) => t, + Err(_) => { + back_off((*next_id).to_owned()); + continue; + }, + }; + + if calculated_event_id != *next_id { + warn!( + "Server didn't return event id we requested: requested: {}, we got {}. 
Event: {:?}", + next_id, calculated_event_id, &res.pdu + ); + } + + if let Some(auth_events) = value.get("auth_events").and_then(|c| c.as_array()) { + for auth_event in auth_events { + if let Ok(auth_event) = serde_json::from_value(auth_event.clone().into()) { + let a: Arc = auth_event; + todo_auth_events.push(a); + } else { + warn!("Auth event id is not valid"); + } + } + } else { + warn!("Auth event list invalid"); + } + + events_in_reverse_order.push((next_id.clone(), value)); + events_all.insert(next_id); + }, + Err(e) => { + warn!("Failed to fetch event {} | {e}", next_id); + back_off((*next_id).to_owned()); + }, + } + } + events_with_auth_events.push((id, None, events_in_reverse_order)); + } + + // We go through all the signatures we see on the PDUs and their unresolved + // dependencies and fetch the corresponding signing keys + info!("fetch_required_signing_keys for {}", origin); + self.fetch_required_signing_keys( + events_with_auth_events + .iter() + .flat_map(|(_id, _local_pdu, events)| events) + .map(|(_event_id, event)| event), + pub_key_map, + ) + .await + .unwrap_or_else(|e| { + warn!("Could not fetch all signatures for PDUs from {}: {:?}", origin, e); + }); + + let mut pdus = vec![]; + for (id, local_pdu, events_in_reverse_order) in events_with_auth_events { + // a. Look in the main timeline (pduid_pdu tree) + // b. 
Look at outlier pdu tree + // (get_pdu_json checks both) + if let Some(local_pdu) = local_pdu { + trace!("Found {} in db", id); + pdus.push((local_pdu, None)); + } + for (next_id, value) in events_in_reverse_order.iter().rev() { + if let Some((time, tries)) = + services().globals.bad_event_ratelimiter.read().unwrap().get(&**next_id) + { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + info!("Backing off from {}", next_id); + continue; + } + } + + match self + .handle_outlier_pdu(origin, create_event, next_id, room_id, value.clone(), true, pub_key_map) + .await + { + Ok((pdu, json)) => { + if next_id == id { + pdus.push((pdu, Some(json))); + } + }, + Err(e) => { + warn!("Authentication of event {} failed: {:?}", next_id, e); + back_off((**next_id).to_owned()); + }, + } + } + } + pdus + }) + } + + async fn fetch_unknown_prev_events( + &self, origin: &ServerName, create_event: &PduEvent, room_id: &RoomId, room_version_id: &RoomVersionId, + pub_key_map: &RwLock>>, initial_set: Vec>, + ) -> Result<( + Vec>, + HashMap, (Arc, BTreeMap)>, + )> { + let mut graph: HashMap, _> = HashMap::new(); + let mut eventid_info = HashMap::new(); + let mut todo_outlier_stack: Vec> = initial_set; + + let first_pdu_in_room = services() + .rooms + .timeline + .first_pdu_in_room(room_id)? 
+ .ok_or_else(|| Error::bad_database("Failed to find first pdu in db."))?; + + let mut amount = 0; + + while let Some(prev_event_id) = todo_outlier_stack.pop() { + if let Some((pdu, json_opt)) = self + .fetch_and_handle_outliers( + origin, + &[prev_event_id.clone()], + create_event, + room_id, + room_version_id, + pub_key_map, + ) + .await + .pop() + { + self.check_room_id(room_id, &pdu)?; + + if amount > services().globals.max_fetch_prev_events() { + // Max limit reached + info!( + "Max prev event limit reached! Limit: {}", + services().globals.max_fetch_prev_events() + ); + graph.insert(prev_event_id.clone(), HashSet::new()); + continue; + } + + if let Some(json) = + json_opt.or_else(|| services().rooms.outlier.get_outlier_pdu_json(&prev_event_id).ok().flatten()) + { + if pdu.origin_server_ts > first_pdu_in_room.origin_server_ts { + amount += 1; + for prev_prev in &pdu.prev_events { + if !graph.contains_key(prev_prev) { + todo_outlier_stack.push(prev_prev.clone()); + } + } + + graph.insert(prev_event_id.clone(), pdu.prev_events.iter().cloned().collect()); + } else { + // Time based check failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + + eventid_info.insert(prev_event_id.clone(), (pdu, json)); + } else { + // Get json failed, so this was not fetched over federation + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } else { + // Fetch and handle failed + graph.insert(prev_event_id.clone(), HashSet::new()); + } + } + + let sorted = state_res::lexicographical_topological_sort(&graph, |event_id| { + // This return value is the key used for sorting events, + // events are then sorted by power level, time, + // and lexically by event_id. 
+ Ok(( + int!(0), + MilliSecondsSinceUnixEpoch( + eventid_info.get(event_id).map_or_else(|| uint!(0), |info| info.0.origin_server_ts), + ), + )) + }) + .map_err(|e| { + error!("Error sorting prev events: {e}"); + Error::bad_database("Error sorting prev events") + })?; + + Ok((sorted, eventid_info)) + } + + #[tracing::instrument(skip_all)] + pub(crate) async fn fetch_required_signing_keys<'a, E>( + &'a self, events: E, pub_key_map: &RwLock>>, + ) -> Result<()> + where + E: IntoIterator>, + { + let mut server_key_ids = HashMap::new(); + + for event in events.into_iter() { + debug!("Fetching keys for event: {event:?}"); + for (signature_server, signature) in event + .get("signatures") + .ok_or(Error::BadServerResponse("No signatures in server response pdu."))? + .as_object() + .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))? + { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + for signature_id in signature_object.keys() { + server_key_ids + .entry(signature_server.clone()) + .or_insert_with(HashSet::new) + .insert(signature_id.clone()); + } + } + } + + if server_key_ids.is_empty() { + // Nothing to do, can exit early + debug!("server_key_ids is empty, not fetching any keys"); + return Ok(()); + } + + info!( + "Fetch keys for {}", + server_key_ids.keys().cloned().collect::>().join(", ") + ); + + let mut server_keys: FuturesUnordered<_> = server_key_ids + .into_iter() + .map(|(signature_server, signature_ids)| async { + let fetch_res = self + .fetch_signing_keys_for_server( + signature_server.as_str().try_into().map_err(|e| { + info!("Invalid servername in signatures of server response pdu: {e}"); + ( + signature_server.clone(), + Error::BadServerResponse("Invalid servername in signatures of server response pdu."), + ) + })?, + signature_ids.into_iter().collect(), // HashSet to Vec + ) + .await; + + match fetch_res { + Ok(keys) => 
Ok((signature_server, keys)), + Err(e) => { + warn!("Signature verification failed: Could not fetch signing key for {signature_server}: {e}",); + Err((signature_server, e)) + }, + } + }) + .collect(); + + while let Some(fetch_res) = server_keys.next().await { + match fetch_res { + Ok((signature_server, keys)) => { + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? + .insert(signature_server.clone(), keys); + }, + Err((signature_server, e)) => { + warn!("Failed to fetch keys for {}: {:?}", signature_server, e); + }, + } + } + + Ok(()) + } + + // Gets a list of servers for which we don't have the signing key yet. We go + // over the PDUs and either cache the key or add it to the list that needs to be + // retrieved. + fn get_server_keys_from_cache( + &self, pdu: &RawJsonValue, + servers: &mut BTreeMap>, + room_version: &RoomVersionId, + pub_key_map: &mut RwLockWriteGuard<'_, BTreeMap>>, + ) -> Result<()> { + let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| { + error!("Invalid PDU in server response: {:?}: {:?}", pdu, e); + Error::BadServerResponse("Invalid PDU in server response") + })?; + + let event_id = format!( + "${}", + ruma::signatures::reference_hash(&value, room_version).expect("ruma can calculate reference hashes") + ); + let event_id = <&EventId>::try_from(event_id.as_str()).expect("ruma's reference hashes are valid event ids"); + + if let Some((time, tries)) = services().globals.bad_event_ratelimiter.read().unwrap().get(event_id) { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {}", event_id); + return Err(Error::BadServerResponse("bad event, still backing off")); + } + } + + let signatures = value + .get("signatures") + 
.ok_or(Error::BadServerResponse("No signatures in server response pdu."))? + .as_object() + .ok_or(Error::BadServerResponse("Invalid signatures object in server response pdu."))?; + + for (signature_server, signature) in signatures { + let signature_object = signature.as_object().ok_or(Error::BadServerResponse( + "Invalid signatures content object in server response pdu.", + ))?; + + let signature_ids = signature_object.keys().cloned().collect::>(); + + let contains_all_ids = + |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + + let origin = <&ServerName>::try_from(signature_server.as_str()).map_err(|e| { + info!("Invalid servername in signatures of server response pdu: {e}"); + Error::BadServerResponse("Invalid servername in signatures of server response pdu.") + })?; + + if servers.contains_key(origin) || pub_key_map.contains_key(origin.as_str()) { + continue; + } + + debug!("Loading signing keys for {}", origin); + + let result: BTreeMap<_, _> = + services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); + + if !contains_all_ids(&result) { + debug!("Signing key not loaded for {}", origin); + servers.insert(origin.to_owned(), BTreeMap::new()); + } + + pub_key_map.insert(origin.to_string(), result); + } + + Ok(()) + } + + async fn batch_request_signing_keys( + &self, mut servers: BTreeMap>, + pub_key_map: &RwLock>>, + ) -> Result<()> { + for server in services().globals.trusted_servers() { + info!("Asking batch signing keys from trusted server {}", server); + if let Ok(keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys_batch::v2::Request { + server_keys: servers.clone(), + }, + ) + .await + { + debug!("Got signing keys: {:?}", keys); + let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?; + for k in keys.server_keys { + let k = match k.deserialize() { + Ok(key) => key, + Err(e) => { + warn!("Received error {} while 
fetching keys from trusted server {}", e, server); + warn!("{}", k.into_json()); + continue; + }, + }; + + // TODO: Check signature from trusted server? + servers.remove(&k.server_name); + + let result = services() + .globals + .add_signing_key(&k.server_name, k.clone())? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect::>(); + + pkm.insert(k.server_name.to_string(), result); + } + } + } + + Ok(()) + } + + async fn request_signing_keys( + &self, servers: BTreeMap>, + pub_key_map: &RwLock>>, + ) -> Result<()> { + info!("Asking individual servers for signing keys: {servers:?}"); + let mut futures: FuturesUnordered<_> = servers + .into_keys() + .map(|server| async move { + ( + services().sending.send_federation_request(&server, get_server_keys::v2::Request::new()).await, + server, + ) + }) + .collect(); + + while let Some(result) = futures.next().await { + debug!("Received new Future result"); + if let (Ok(get_keys_response), origin) = result { + info!("Result is from {origin}"); + if let Ok(key) = get_keys_response.server_key.deserialize() { + let result: BTreeMap<_, _> = services() + .globals + .add_signing_key(&origin, key)? + .into_iter() + .map(|(k, v)| (k.to_string(), v.key)) + .collect(); + pub_key_map + .write() + .map_err(|_| Error::bad_database("RwLock is poisoned."))? 
+ .insert(origin.to_string(), result); + } + } + debug!("Done handling Future result"); + } + + Ok(()) + } + + pub(crate) async fn fetch_join_signing_keys( + &self, event: &create_join_event::v2::Response, room_version: &RoomVersionId, + pub_key_map: &RwLock>>, + ) -> Result<()> { + let mut servers: BTreeMap> = BTreeMap::new(); + + { + let mut pkm = pub_key_map.write().map_err(|_| Error::bad_database("RwLock is poisoned."))?; + + // Try to fetch keys, failure is okay + // Servers we couldn't find in the cache will be added to `servers` + for pdu in &event.room_state.state { + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + } + for pdu in &event.room_state.auth_chain { + let _ = self.get_server_keys_from_cache(pdu, &mut servers, room_version, &mut pkm); + } + + drop(pkm); + }; + + if servers.is_empty() { + info!("We had all keys cached locally, not fetching any keys from remote servers"); + return Ok(()); + } + + if services().globals.query_trusted_key_servers_first() { + info!( + "query_trusted_key_servers_first is set to true, querying notary trusted key servers first for \ + homeserver signing keys." 
+ ); + + self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; + + if servers.is_empty() { + info!("Trusted server supplied all signing keys, no more keys to fetch"); + return Ok(()); + } + + info!("Remaining servers left that the notary/trusted servers did not provide: {servers:?}"); + + self.request_signing_keys(servers.clone(), pub_key_map).await?; + } else { + info!("query_trusted_key_servers_first is set to false, querying individual homeservers first"); + + self.request_signing_keys(servers.clone(), pub_key_map).await?; + + if servers.is_empty() { + info!("Individual homeservers supplied all signing keys, no more keys to fetch"); + return Ok(()); + } + + info!("Remaining servers left the individual homeservers did not provide: {servers:?}"); + + self.batch_request_signing_keys(servers.clone(), pub_key_map).await?; + } + + info!("Search for signing keys done"); + + /*if servers.is_empty() { + warn!("Failed to find homeserver signing keys for the remaining servers: {servers:?}"); + }*/ + + Ok(()) + } + + /// Returns Ok if the acl allows the server + pub fn acl_check(&self, server_name: &ServerName, room_id: &RoomId) -> Result<()> { + let acl_event = + match services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomServerAcl, "")? 
{ + Some(acl) => { + debug!("ACL event found: {acl:?}"); + acl + }, + None => { + info!("No ACL event found"); + return Ok(()); + }, + }; + + let acl_event_content: RoomServerAclEventContent = match serde_json::from_str(acl_event.content.get()) { + Ok(content) => { + debug!("Found ACL event contents: {content:?}"); + content + }, + Err(e) => { + warn!("Invalid ACL event: {e}"); + return Ok(()); + }, + }; + + if acl_event_content.allow.is_empty() { + warn!("Ignoring broken ACL event (allow key is empty)"); + // Ignore broken acl events + return Ok(()); + } + + if acl_event_content.is_allowed(server_name) { + debug!("server {server_name} is allowed by ACL"); + Ok(()) + } else { + info!("Server {} was denied by room ACL in {}", server_name, room_id); + Err(Error::BadRequest(ErrorKind::Forbidden, "Server was denied by room ACL")) + } + } + + /// Search the DB for the signing keys of the given server, if we don't have + /// them fetch them from the server and save to our DB. + #[tracing::instrument(skip_all)] + pub async fn fetch_signing_keys_for_server( + &self, origin: &ServerName, signature_ids: Vec, + ) -> Result> { + let contains_all_ids = |keys: &BTreeMap| signature_ids.iter().all(|id| keys.contains_key(id)); + + let permit = services() + .globals + .servername_ratelimiter + .read() + .unwrap() + .get(origin) + .map(|s| Arc::clone(s).acquire_owned()); + + let permit = match permit { + Some(p) => p, + None => { + let mut write = services().globals.servername_ratelimiter.write().unwrap(); + let s = Arc::clone(write.entry(origin.to_owned()).or_insert_with(|| Arc::new(Semaphore::new(1)))); + + s.acquire_owned() + }, + } + .await; + + let back_off = |id| match services().globals.bad_signature_ratelimiter.write().unwrap().entry(id) { + hash_map::Entry::Vacant(e) => { + e.insert((Instant::now(), 1)); + }, + hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1), + }; + + if let Some((time, tries)) = 
services().globals.bad_signature_ratelimiter.read().unwrap().get(&signature_ids) { + // Exponential backoff + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + debug!("Backing off from {:?}", signature_ids); + return Err(Error::BadServerResponse("bad signature, still backing off")); + } + } + + debug!("Loading signing keys for {}", origin); + + let mut result: BTreeMap<_, _> = + services().globals.signing_keys_for(origin)?.into_iter().map(|(k, v)| (k.to_string(), v.key)).collect(); + + if contains_all_ids(&result) { + return Ok(result); + } + + debug!("Fetching signing keys for {} over federation", origin); + + if let Some(server_key) = services() + .sending + .send_federation_request(origin, get_server_keys::v2::Request::new()) + .await + .ok() + .and_then(|resp| resp.server_key.deserialize().ok()) + { + services().globals.add_signing_key(origin, server_key.clone())?; + + result.extend(server_key.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + result.extend(server_key.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + + if contains_all_ids(&result) { + return Ok(result); + } + } + + for server in services().globals.trusted_servers() { + debug!("Asking {} for {}'s signing key", server, origin); + if let Some(server_keys) = services() + .sending + .send_federation_request( + server, + get_remote_server_keys::v2::Request::new( + origin.to_owned(), + MilliSecondsSinceUnixEpoch::from_system_time( + SystemTime::now().checked_add(Duration::from_secs(3600)).expect("SystemTime too large"), + ) + .expect("time is valid"), + ), + ) + .await + .ok() + .map(|resp| resp.server_keys.into_iter().filter_map(|e| e.deserialize().ok()).collect::>()) + { + debug!("Got signing keys: {:?}", server_keys); + for k in server_keys { + 
services().globals.add_signing_key(origin, k.clone())?; + result.extend(k.verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + result.extend(k.old_verify_keys.into_iter().map(|(k, v)| (k.to_string(), v.key))); + } + + if contains_all_ids(&result) { + return Ok(result); + } + } + } + + drop(permit); + + back_off(signature_ids); + + warn!("Failed to find public key for server: {}", origin); + Err(Error::BadServerResponse("Failed to find public key for server")) + } + + fn check_room_id(&self, room_id: &RoomId, pdu: &PduEvent) -> Result<()> { + if pdu.room_id != room_id { + warn!("Found event from room {} in room {}", pdu.room_id, room_id); + return Err(Error::BadRequest(ErrorKind::InvalidParam, "Event has wrong room id")); + } + Ok(()) + } } diff --git a/src/service/rooms/lazy_loading/data.rs b/src/service/rooms/lazy_loading/data.rs index 9af8e21b..890a2f98 100644 --- a/src/service/rooms/lazy_loading/data.rs +++ b/src/service/rooms/lazy_loading/data.rs @@ -1,27 +1,16 @@ -use crate::Result; use ruma::{DeviceId, RoomId, UserId}; +use crate::Result; + pub trait Data: Send + Sync { - fn lazy_load_was_sent_before( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ll_user: &UserId, - ) -> Result; + fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, + ) -> Result; - fn lazy_load_confirm_delivery( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - confirmed_user_ids: &mut dyn Iterator, - ) -> Result<()>; + fn lazy_load_confirm_delivery( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, + confirmed_user_ids: &mut dyn Iterator, + ) -> Result<()>; - fn lazy_load_reset( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ) -> Result<()>; + fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()>; } diff --git a/src/service/rooms/lazy_loading/mod.rs 
b/src/service/rooms/lazy_loading/mod.rs index d466231e..b925fc0f 100644 --- a/src/service/rooms/lazy_loading/mod.rs +++ b/src/service/rooms/lazy_loading/mod.rs @@ -1,92 +1,62 @@ mod data; use std::{ - collections::{HashMap, HashSet}, - sync::Mutex, + collections::{HashMap, HashSet}, + sync::Mutex, }; pub use data::Data; use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId}; +use super::timeline::PduCount; use crate::Result; -use super::timeline::PduCount; - -type LazyLoadWaitingMutex = - Mutex>>; +type LazyLoadWaitingMutex = Mutex>>; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, - pub lazy_load_waiting: LazyLoadWaitingMutex, + pub lazy_load_waiting: LazyLoadWaitingMutex, } impl Service { - #[tracing::instrument(skip(self))] - pub fn lazy_load_was_sent_before( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ll_user: &UserId, - ) -> Result { - self.db - .lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) - } + #[tracing::instrument(skip(self))] + pub fn lazy_load_was_sent_before( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId, + ) -> Result { + self.db.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user) + } - #[tracing::instrument(skip(self))] - pub fn lazy_load_mark_sent( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - lazy_load: HashSet, - count: PduCount, - ) { - self.lazy_load_waiting.lock().unwrap().insert( - ( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - count, - ), - lazy_load, - ); - } + #[tracing::instrument(skip(self))] + pub fn lazy_load_mark_sent( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet, + count: PduCount, + ) { + self.lazy_load_waiting + .lock() + .unwrap() + .insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load); + } - #[tracing::instrument(skip(self))] - pub fn 
lazy_load_confirm_delivery( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - since: PduCount, - ) -> Result<()> { - if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( - user_id.to_owned(), - device_id.to_owned(), - room_id.to_owned(), - since, - )) { - self.db.lazy_load_confirm_delivery( - user_id, - device_id, - room_id, - &mut user_ids.iter().map(|u| &**u), - )?; - } else { - // Ignore - } + #[tracing::instrument(skip(self))] + pub fn lazy_load_confirm_delivery( + &self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount, + ) -> Result<()> { + if let Some(user_ids) = self.lazy_load_waiting.lock().unwrap().remove(&( + user_id.to_owned(), + device_id.to_owned(), + room_id.to_owned(), + since, + )) { + self.db.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?; + } else { + // Ignore + } - Ok(()) - } + Ok(()) + } - #[tracing::instrument(skip(self))] - pub fn lazy_load_reset( - &self, - user_id: &UserId, - device_id: &DeviceId, - room_id: &RoomId, - ) -> Result<()> { - self.db.lazy_load_reset(user_id, device_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> { + self.db.lazy_load_reset(user_id, device_id, room_id) + } } diff --git a/src/service/rooms/metadata/data.rs b/src/service/rooms/metadata/data.rs index 7c7e10be..d702b203 100644 --- a/src/service/rooms/metadata/data.rs +++ b/src/service/rooms/metadata/data.rs @@ -1,12 +1,13 @@ -use crate::Result; use ruma::{OwnedRoomId, RoomId}; +use crate::Result; + pub trait Data: Send + Sync { - fn exists(&self, room_id: &RoomId) -> Result; - fn iter_ids<'a>(&'a self) -> Box> + 'a>; - fn is_disabled(&self, room_id: &RoomId) -> Result; - fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; - fn is_banned(&self, room_id: &RoomId) -> Result; - fn ban_room(&self, room_id: &RoomId, banned: bool) 
-> Result<()>; - fn list_banned_rooms<'a>(&'a self) -> Box> + 'a>; + fn exists(&self, room_id: &RoomId) -> Result; + fn iter_ids<'a>(&'a self) -> Box> + 'a>; + fn is_disabled(&self, room_id: &RoomId) -> Result; + fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()>; + fn is_banned(&self, room_id: &RoomId) -> Result; + fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()>; + fn list_banned_rooms<'a>(&'a self) -> Box> + 'a>; } diff --git a/src/service/rooms/metadata/mod.rs b/src/service/rooms/metadata/mod.rs index 69ae6dbc..500ddcff 100644 --- a/src/service/rooms/metadata/mod.rs +++ b/src/service/rooms/metadata/mod.rs @@ -6,37 +6,27 @@ use ruma::{OwnedRoomId, RoomId}; use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Checks if a room exists. - #[tracing::instrument(skip(self))] - pub fn exists(&self, room_id: &RoomId) -> Result { - self.db.exists(room_id) - } + /// Checks if a room exists. + #[tracing::instrument(skip(self))] + pub fn exists(&self, room_id: &RoomId) -> Result { self.db.exists(room_id) } - pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { - self.db.iter_ids() - } + pub fn iter_ids<'a>(&'a self) -> Box> + 'a> { self.db.iter_ids() } - pub fn is_disabled(&self, room_id: &RoomId) -> Result { - self.db.is_disabled(room_id) - } + pub fn is_disabled(&self, room_id: &RoomId) -> Result { self.db.is_disabled(room_id) } - pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { - self.db.disable_room(room_id, disabled) - } + pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> { + self.db.disable_room(room_id, disabled) + } - pub fn is_banned(&self, room_id: &RoomId) -> Result { - self.db.is_banned(room_id) - } + pub fn is_banned(&self, room_id: &RoomId) -> Result { self.db.is_banned(room_id) } - pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { - self.db.ban_room(room_id, banned) - } + pub fn 
ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) } - pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { - self.db.list_banned_rooms() - } + pub fn list_banned_rooms<'a>(&'a self) -> Box> + 'a> { + self.db.list_banned_rooms() + } } diff --git a/src/service/rooms/mod.rs b/src/service/rooms/mod.rs index f0739841..668efee5 100644 --- a/src/service/rooms/mod.rs +++ b/src/service/rooms/mod.rs @@ -19,44 +19,44 @@ pub mod timeline; pub mod user; pub trait Data: - alias::Data - + auth_chain::Data - + directory::Data - + edus::Data - + lazy_loading::Data - + metadata::Data - + outlier::Data - + pdu_metadata::Data - + search::Data - + short::Data - + state::Data - + state_accessor::Data - + state_cache::Data - + state_compressor::Data - + timeline::Data - + threads::Data - + user::Data + alias::Data + + auth_chain::Data + + directory::Data + + edus::Data + + lazy_loading::Data + + metadata::Data + + outlier::Data + + pdu_metadata::Data + + search::Data + + short::Data + + state::Data + + state_accessor::Data + + state_cache::Data + + state_compressor::Data + + timeline::Data + + threads::Data + + user::Data { } pub struct Service { - pub alias: alias::Service, - pub auth_chain: auth_chain::Service, - pub directory: directory::Service, - pub edus: edus::Service, - pub event_handler: event_handler::Service, - pub lazy_loading: lazy_loading::Service, - pub metadata: metadata::Service, - pub outlier: outlier::Service, - pub pdu_metadata: pdu_metadata::Service, - pub search: search::Service, - pub short: short::Service, - pub state: state::Service, - pub state_accessor: state_accessor::Service, - pub state_cache: state_cache::Service, - pub state_compressor: state_compressor::Service, - pub timeline: timeline::Service, - pub threads: threads::Service, - pub spaces: spaces::Service, - pub user: user::Service, + pub alias: alias::Service, + pub auth_chain: auth_chain::Service, + pub directory: directory::Service, + pub edus: 
edus::Service, + pub event_handler: event_handler::Service, + pub lazy_loading: lazy_loading::Service, + pub metadata: metadata::Service, + pub outlier: outlier::Service, + pub pdu_metadata: pdu_metadata::Service, + pub search: search::Service, + pub short: short::Service, + pub state: state::Service, + pub state_accessor: state_accessor::Service, + pub state_cache: state_cache::Service, + pub state_compressor: state_compressor::Service, + pub timeline: timeline::Service, + pub threads: threads::Service, + pub spaces: spaces::Service, + pub user: user::Service, } diff --git a/src/service/rooms/outlier/data.rs b/src/service/rooms/outlier/data.rs index 0ed521dd..18eb3190 100644 --- a/src/service/rooms/outlier/data.rs +++ b/src/service/rooms/outlier/data.rs @@ -3,7 +3,7 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; pub trait Data: Send + Sync { - fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; - fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; - fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; + fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result>; + fn get_outlier_pdu(&self, event_id: &EventId) -> Result>; + fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()>; } diff --git a/src/service/rooms/outlier/mod.rs b/src/service/rooms/outlier/mod.rs index dae41e4b..7a6a1d01 100644 --- a/src/service/rooms/outlier/mod.rs +++ b/src/service/rooms/outlier/mod.rs @@ -6,23 +6,21 @@ use ruma::{CanonicalJsonObject, EventId}; use crate::{PduEvent, Result}; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Returns the pdu from the outlier tree. - pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_outlier_pdu_json(event_id) - } + /// Returns the pdu from the outlier tree. 
+ pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.db.get_outlier_pdu_json(event_id) + } - /// Returns the pdu from the outlier tree. - pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { - self.db.get_outlier_pdu(event_id) - } + /// Returns the pdu from the outlier tree. + pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result> { self.db.get_outlier_pdu(event_id) } - /// Append the PDU as an outlier. - #[tracing::instrument(skip(self, pdu))] - pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { - self.db.add_pdu_outlier(event_id, pdu) - } + /// Append the PDU as an outlier. + #[tracing::instrument(skip(self, pdu))] + pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> { + self.db.add_pdu_outlier(event_id, pdu) + } } diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index a4df34cc..8d9a2058 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -1,20 +1,17 @@ use std::sync::Arc; -use crate::{service::rooms::timeline::PduCount, PduEvent, Result}; use ruma::{EventId, RoomId, UserId}; +use crate::{service::rooms::timeline::PduCount, PduEvent, Result}; + pub trait Data: Send + Sync { - fn add_relation(&self, from: u64, to: u64) -> Result<()>; - #[allow(clippy::type_complexity)] - fn relations_until<'a>( - &'a self, - user_id: &'a UserId, - room_id: u64, - target: u64, - until: PduCount, - ) -> Result> + 'a>>; - fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; - fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; - fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; - fn is_event_soft_failed(&self, event_id: &EventId) -> Result; + fn add_relation(&self, from: u64, to: u64) -> Result<()>; + #[allow(clippy::type_complexity)] + fn relations_until<'a>( + &'a self, user_id: &'a UserId, room_id: 
u64, target: u64, until: PduCount, + ) -> Result> + 'a>>; + fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()>; + fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result; + fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()>; + fn is_event_soft_failed(&self, event_id: &EventId) -> Result; } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index fcc8335e..1ffc8a79 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -3,188 +3,166 @@ use std::sync::Arc; pub use data::Data; use ruma::{ - api::client::relations::get_relating_events, - events::{relation::RelationType, TimelineEventType}, - EventId, RoomId, UserId, + api::client::relations::get_relating_events, + events::{relation::RelationType, TimelineEventType}, + EventId, RoomId, UserId, }; use serde::Deserialize; +use super::timeline::PduCount; use crate::{services, PduEvent, Result}; -use super::timeline::PduCount; - pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } #[derive(Clone, Debug, Deserialize)] struct ExtractRelType { - rel_type: RelationType, + rel_type: RelationType, } #[derive(Clone, Debug, Deserialize)] struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractRelType, + #[serde(rename = "m.relates_to")] + relates_to: ExtractRelType, } impl Service { - #[tracing::instrument(skip(self, from, to))] - pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { - match (from, to) { - (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), - _ => { - // TODO: Relations with backfilled pdus + #[tracing::instrument(skip(self, from, to))] + pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> { + match (from, to) { + (PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t), + _ => { + // TODO: Relations with backfilled pdus - Ok(()) - 
} - } - } + Ok(()) + }, + } + } - #[allow(clippy::too_many_arguments)] - pub fn paginate_relations_with_filter( - &self, - sender_user: &UserId, - room_id: &RoomId, - target: &EventId, - filter_event_type: Option, - filter_rel_type: Option, - from: PduCount, - to: Option, - limit: usize, - ) -> Result { - let next_token; + #[allow(clippy::too_many_arguments)] + pub fn paginate_relations_with_filter( + &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, + filter_rel_type: Option, from: PduCount, to: Option, limit: usize, + ) -> Result { + let next_token; - //TODO: Fix ruma: match body.dir { - match ruma::api::Direction::Backward { - ruma::api::Direction::Forward => { - let events_after: Vec<_> = services() - .rooms - .pdu_metadata - .relations_until(sender_user, room_id, target, from)? // TODO: should be relations_after - .filter(|r| { - r.as_ref().map_or(true, |(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::( - pdu.content.get(), - ) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - }) - .take(limit) - .filter_map(std::result::Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` - .collect(); + //TODO: Fix ruma: match body.dir { + match ruma::api::Direction::Backward { + ruma::api::Direction::Forward => { + let events_after: Vec<_> = services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from)? 
// TODO: should be relations_after + .filter(|r| { + r.as_ref().map_or(true, |(_, pdu)| { + filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) + && if let Ok(content) = + serde_json::from_str::(pdu.content.get()) + { + filter_rel_type + .as_ref() + .map_or(true, |r| &content.relates_to.rel_type == r) + } else { + false + } + }) + }) + .take(limit) + .filter_map(std::result::Result::ok) // Filter out buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .unwrap_or(false) + }) + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .collect(); - next_token = events_after.last().map(|(count, _)| count).copied(); + next_token = events_after.last().map(|(count, _)| count).copied(); - let events_after: Vec<_> = events_after - .into_iter() - .rev() // relations are always most recent first - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(); + let events_after: Vec<_> = events_after + .into_iter() + .rev() // relations are always most recent first + .map(|(_, pdu)| pdu.to_message_like_event()) + .collect(); - Ok(get_relating_events::v1::Response { - chunk: events_after, - next_batch: next_token.map(|t| t.stringify()), - prev_batch: Some(from.stringify()), - }) - } - ruma::api::Direction::Backward => { - let events_before: Vec<_> = services() - .rooms - .pdu_metadata - .relations_until(sender_user, room_id, target, from)? 
- .filter(|r| { - r.as_ref().map_or(true, |(_, pdu)| { - filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) - && if let Ok(content) = - serde_json::from_str::( - pdu.content.get(), - ) - { - filter_rel_type - .as_ref() - .map_or(true, |r| &content.relates_to.rel_type == r) - } else { - false - } - }) - }) - .take(limit) - .filter_map(std::result::Result::ok) // Filter out buggy events - .filter(|(_, pdu)| { - services() - .rooms - .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) - .unwrap_or(false) - }) - .take_while(|&(k, _)| Some(k) != to) // Stop at `to` - .collect(); + Ok(get_relating_events::v1::Response { + chunk: events_after, + next_batch: next_token.map(|t| t.stringify()), + prev_batch: Some(from.stringify()), + }) + }, + ruma::api::Direction::Backward => { + let events_before: Vec<_> = services() + .rooms + .pdu_metadata + .relations_until(sender_user, room_id, target, from)? + .filter(|r| { + r.as_ref().map_or(true, |(_, pdu)| { + filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) + && if let Ok(content) = + serde_json::from_str::(pdu.content.get()) + { + filter_rel_type + .as_ref() + .map_or(true, |r| &content.relates_to.rel_type == r) + } else { + false + } + }) + }) + .take(limit) + .filter_map(std::result::Result::ok) // Filter out buggy events + .filter(|(_, pdu)| { + services() + .rooms + .state_accessor + .user_can_see_event(sender_user, room_id, &pdu.event_id) + .unwrap_or(false) + }) + .take_while(|&(k, _)| Some(k) != to) // Stop at `to` + .collect(); - next_token = events_before.last().map(|(count, _)| count).copied(); + next_token = events_before.last().map(|(count, _)| count).copied(); - let events_before: Vec<_> = events_before - .into_iter() - .map(|(_, pdu)| pdu.to_message_like_event()) - .collect(); + let events_before: Vec<_> = + events_before.into_iter().map(|(_, pdu)| pdu.to_message_like_event()).collect(); - Ok(get_relating_events::v1::Response { - chunk: events_before, - next_batch: 
next_token.map(|t| t.stringify()), - prev_batch: Some(from.stringify()), - }) - } - } - } + Ok(get_relating_events::v1::Response { + chunk: events_before, + next_batch: next_token.map(|t| t.stringify()), + prev_batch: Some(from.stringify()), + }) + }, + } + } - pub fn relations_until<'a>( - &'a self, - user_id: &'a UserId, - room_id: &'a RoomId, - target: &'a EventId, - until: PduCount, - ) -> Result> + 'a> { - let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; - let target = match services().rooms.timeline.get_pdu_count(target)? { - Some(PduCount::Normal(c)) => c, - // TODO: Support backfilled relations - _ => 0, // This will result in an empty iterator - }; - self.db.relations_until(user_id, room_id, target, until) - } + pub fn relations_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, + ) -> Result> + 'a> { + let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; + let target = match services().rooms.timeline.get_pdu_count(target)? 
{ + Some(PduCount::Normal(c)) => c, + // TODO: Support backfilled relations + _ => 0, // This will result in an empty iterator + }; + self.db.relations_until(user_id, room_id, target, until) + } - #[tracing::instrument(skip(self, room_id, event_ids))] - pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { - self.db.mark_as_referenced(room_id, event_ids) - } + #[tracing::instrument(skip(self, room_id, event_ids))] + pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) -> Result<()> { + self.db.mark_as_referenced(room_id, event_ids) + } - #[tracing::instrument(skip(self))] - pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { - self.db.is_event_referenced(room_id, event_id) - } + #[tracing::instrument(skip(self))] + pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result { + self.db.is_event_referenced(room_id, event_id) + } - #[tracing::instrument(skip(self))] - pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { - self.db.mark_event_soft_failed(event_id) - } + #[tracing::instrument(skip(self))] + pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) } - #[tracing::instrument(skip(self))] - pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { - self.db.is_event_soft_failed(event_id) - } + #[tracing::instrument(skip(self))] + pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result { self.db.is_event_soft_failed(event_id) } } diff --git a/src/service/rooms/search/data.rs b/src/service/rooms/search/data.rs index 88fd88e5..96439adf 100644 --- a/src/service/rooms/search/data.rs +++ b/src/service/rooms/search/data.rs @@ -1,10 +1,11 @@ -use crate::Result; use ruma::RoomId; +use crate::Result; + type SearchPdusResult<'a> = Result> + 'a>, Vec)>>; pub trait Data: Send + Sync { - fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; + 
fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()>; - fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; + fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a>; } diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 139169ef..e75f7d14 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,26 +1,24 @@ mod data; pub use data::Data; - -use crate::Result; use ruma::RoomId; +use crate::Result; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - #[tracing::instrument(skip(self))] - pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { - self.db.index_pdu(shortroomid, pdu_id, message_body) - } + #[tracing::instrument(skip(self))] + pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> { + self.db.index_pdu(shortroomid, pdu_id, message_body) + } - #[tracing::instrument(skip(self))] - pub fn search_pdus<'a>( - &'a self, - room_id: &RoomId, - search_string: &str, - ) -> Result> + 'a, Vec)>> { - self.db.search_pdus(room_id, search_string) - } + #[tracing::instrument(skip(self))] + pub fn search_pdus<'a>( + &'a self, room_id: &RoomId, search_string: &str, + ) -> Result> + 'a, Vec)>> { + self.db.search_pdus(room_id, search_string) + } } diff --git a/src/service/rooms/short/data.rs b/src/service/rooms/short/data.rs index 652c525b..aa891e7a 100644 --- a/src/service/rooms/short/data.rs +++ b/src/service/rooms/short/data.rs @@ -1,31 +1,24 @@ use std::sync::Arc; -use crate::Result; use ruma::{events::StateEventType, EventId, RoomId}; +use crate::Result; + pub trait Data: Send + Sync { - fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; + fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result; - fn get_shortstatekey( - &self, - event_type: 
&StateEventType, - state_key: &str, - ) -> Result>; + fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result>; - fn get_or_create_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result; + fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result; - fn get_eventid_from_short(&self, shorteventid: u64) -> Result>; + fn get_eventid_from_short(&self, shorteventid: u64) -> Result>; - fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; + fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)>; - /// Returns (shortstatehash, already_existed) - fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; + /// Returns (shortstatehash, already_existed) + fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)>; - fn get_shortroomid(&self, room_id: &RoomId) -> Result>; + fn get_shortroomid(&self, room_id: &RoomId) -> Result>; - fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; + fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index 45fadd74..7e8623d2 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -7,48 +7,38 @@ use ruma::{events::StateEventType, EventId, RoomId}; use crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { - self.db.get_or_create_shorteventid(event_id) - } + pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result { + self.db.get_or_create_shorteventid(event_id) + } - pub fn get_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result> { - self.db.get_shortstatekey(event_type, state_key) - } + pub fn get_shortstatekey(&self, 
event_type: &StateEventType, state_key: &str) -> Result> { + self.db.get_shortstatekey(event_type, state_key) + } - pub fn get_or_create_shortstatekey( - &self, - event_type: &StateEventType, - state_key: &str, - ) -> Result { - self.db.get_or_create_shortstatekey(event_type, state_key) - } + pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result { + self.db.get_or_create_shortstatekey(event_type, state_key) + } - pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { - self.db.get_eventid_from_short(shorteventid) - } + pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result> { + self.db.get_eventid_from_short(shorteventid) + } - pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { - self.db.get_statekey_from_short(shortstatekey) - } + pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> { + self.db.get_statekey_from_short(shortstatekey) + } - /// Returns (shortstatehash, already_existed) - pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { - self.db.get_or_create_shortstatehash(state_hash) - } + /// Returns (shortstatehash, already_existed) + pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> { + self.db.get_or_create_shortstatehash(state_hash) + } - pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { - self.db.get_shortroomid(room_id) - } + pub fn get_shortroomid(&self, room_id: &RoomId) -> Result> { self.db.get_shortroomid(room_id) } - pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { - self.db.get_or_create_shortroomid(room_id) - } + pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result { + self.db.get_or_create_shortroomid(room_id) + } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index caeacc52..2021df07 100644 --- a/src/service/rooms/spaces/mod.rs +++ 
b/src/service/rooms/spaces/mod.rs @@ -2,505 +2,419 @@ use std::sync::{Arc, Mutex}; use lru_cache::LruCache; use ruma::{ - api::{ - client::{ - error::ErrorKind, - space::{get_hierarchy, SpaceHierarchyRoomsChunk}, - }, - federation, - }, - events::{ - room::{ - avatar::RoomAvatarEventContent, - canonical_alias::RoomCanonicalAliasEventContent, - create::RoomCreateEventContent, - guest_access::{GuestAccess, RoomGuestAccessEventContent}, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - join_rules::{self, AllowRule, JoinRule, RoomJoinRulesEventContent}, - topic::RoomTopicEventContent, - }, - space::child::SpaceChildEventContent, - StateEventType, - }, - space::SpaceRoomJoinRule, - OwnedRoomId, RoomId, UserId, + api::{ + client::{ + error::ErrorKind, + space::{get_hierarchy, SpaceHierarchyRoomsChunk}, + }, + federation, + }, + events::{ + room::{ + avatar::RoomAvatarEventContent, + canonical_alias::RoomCanonicalAliasEventContent, + create::RoomCreateEventContent, + guest_access::{GuestAccess, RoomGuestAccessEventContent}, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + join_rules::{self, AllowRule, JoinRule, RoomJoinRulesEventContent}, + topic::RoomTopicEventContent, + }, + space::child::SpaceChildEventContent, + StateEventType, + }, + space::SpaceRoomJoinRule, + OwnedRoomId, RoomId, UserId, }; - use tracing::{debug, error, warn}; use crate::{services, Error, PduEvent, Result}; pub enum CachedJoinRule { - //Simplified(SpaceRoomJoinRule), - Full(JoinRule), + //Simplified(SpaceRoomJoinRule), + Full(JoinRule), } pub struct CachedSpaceChunk { - chunk: SpaceHierarchyRoomsChunk, - children: Vec, - join_rule: CachedJoinRule, + chunk: SpaceHierarchyRoomsChunk, + children: Vec, + join_rule: CachedJoinRule, } pub struct Service { - pub roomid_spacechunk_cache: Mutex>>, + pub roomid_spacechunk_cache: Mutex>>, } impl Service { - pub async fn get_hierarchy( - &self, - sender_user: &UserId, - room_id: &RoomId, - limit: 
usize, - skip: usize, - max_depth: usize, - suggested_only: bool, - ) -> Result { - let mut left_to_skip = skip; + pub async fn get_hierarchy( + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, skip: usize, max_depth: usize, + suggested_only: bool, + ) -> Result { + let mut left_to_skip = skip; - let mut rooms_in_path = Vec::new(); - let mut stack = vec![vec![room_id.to_owned()]]; - let mut results = Vec::new(); + let mut rooms_in_path = Vec::new(); + let mut stack = vec![vec![room_id.to_owned()]]; + let mut results = Vec::new(); - while let Some(current_room) = { - while stack.last().map_or(false, std::vec::Vec::is_empty) { - stack.pop(); - } - if !stack.is_empty() { - stack.last_mut().and_then(std::vec::Vec::pop) - } else { - None - } - } { - rooms_in_path.push(current_room.clone()); - if results.len() >= limit { - break; - } + while let Some(current_room) = { + while stack.last().map_or(false, std::vec::Vec::is_empty) { + stack.pop(); + } + if !stack.is_empty() { + stack.last_mut().and_then(std::vec::Vec::pop) + } else { + None + } + } { + rooms_in_path.push(current_room.clone()); + if results.len() >= limit { + break; + } - if let Some(cached) = self - .roomid_spacechunk_cache - .lock() - .unwrap() - .get_mut(¤t_room.clone()) - .as_ref() - { - if let Some(cached) = cached { - let allowed = match &cached.join_rule { - //CachedJoinRule::Simplified(s) => { - //self.handle_simplified_join_rule(s, sender_user, ¤t_room)? - //} - CachedJoinRule::Full(f) => { - self.handle_join_rule(f, sender_user, ¤t_room)? 
- } - }; - if allowed { - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(cached.chunk.clone()); - } - if rooms_in_path.len() < max_depth { - stack.push(cached.children.clone()); - } - } - } - continue; - } + if let Some(cached) = self.roomid_spacechunk_cache.lock().unwrap().get_mut(¤t_room.clone()).as_ref() { + if let Some(cached) = cached { + let allowed = match &cached.join_rule { + //CachedJoinRule::Simplified(s) => { + //self.handle_simplified_join_rule(s, sender_user, ¤t_room)? + //} + CachedJoinRule::Full(f) => self.handle_join_rule(f, sender_user, ¤t_room)?, + }; + if allowed { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(cached.chunk.clone()); + } + if rooms_in_path.len() < max_depth { + stack.push(cached.children.clone()); + } + } + } + continue; + } - if let Some(current_shortstatehash) = services() - .rooms - .state - .get_room_shortstatehash(¤t_room)? - { - let state = services() - .rooms - .state_accessor - .state_full_ids(current_shortstatehash) - .await?; + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(¤t_room)? { + let state = services().rooms.state_accessor.state_full_ids(current_shortstatehash).await?; - let mut children_ids = Vec::new(); - let mut children_pdus = Vec::new(); - for (key, id) in state { - let (event_type, state_key) = - services().rooms.short.get_statekey_from_short(key)?; - if event_type != StateEventType::SpaceChild { - continue; - } + let mut children_ids = Vec::new(); + let mut children_pdus = Vec::new(); + for (key, id) in state { + let (event_type, state_key) = services().rooms.short.get_statekey_from_short(key)?; + if event_type != StateEventType::SpaceChild { + continue; + } - let pdu = services() - .rooms - .timeline - .get_pdu(&id)? - .ok_or_else(|| Error::bad_database("Event in space state not found"))?; + let pdu = services() + .rooms + .timeline + .get_pdu(&id)? 
+ .ok_or_else(|| Error::bad_database("Event in space state not found"))?; - if serde_json::from_str::(pdu.content.get()) - .ok() - .map(|c| c.via) - .map_or(true, |v| v.is_empty()) - { - continue; - } + if serde_json::from_str::(pdu.content.get()) + .ok() + .map(|c| c.via) + .map_or(true, |v| v.is_empty()) + { + continue; + } - if let Ok(room_id) = OwnedRoomId::try_from(state_key) { - children_ids.push(room_id); - children_pdus.push(pdu); - } - } + if let Ok(room_id) = OwnedRoomId::try_from(state_key) { + children_ids.push(room_id); + children_pdus.push(pdu); + } + } - // TODO: Sort children - children_ids.reverse(); + // TODO: Sort children + children_ids.reverse(); - let chunk = self - .get_room_chunk(sender_user, ¤t_room, children_pdus) - .await; - if let Ok(chunk) = chunk { - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(chunk.clone()); - } - let join_rule = services() - .rooms - .state_accessor - .room_state_get(¤t_room, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) - }) - .transpose()? - .unwrap_or(JoinRule::Invite); + let chunk = self.get_room_chunk(sender_user, ¤t_room, children_pdus).await; + if let Ok(chunk) = chunk { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(chunk.clone()); + } + let join_rule = services() + .rooms + .state_accessor + .room_state_get(¤t_room, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomJoinRulesEventContent| c.join_rule) + .map_err(|e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }) + }) + .transpose()? 
+ .unwrap_or(JoinRule::Invite); - self.roomid_spacechunk_cache.lock().unwrap().insert( - current_room.clone(), - Some(CachedSpaceChunk { - chunk, - children: children_ids.clone(), - join_rule: CachedJoinRule::Full(join_rule), - }), - ); - } + self.roomid_spacechunk_cache.lock().unwrap().insert( + current_room.clone(), + Some(CachedSpaceChunk { + chunk, + children: children_ids.clone(), + join_rule: CachedJoinRule::Full(join_rule), + }), + ); + } - if rooms_in_path.len() < max_depth { - stack.push(children_ids); - } - } else if let Some(server) = current_room.server_name() { - if server == services().globals.server_name() { - continue; - } + if rooms_in_path.len() < max_depth { + stack.push(children_ids); + } + } else if let Some(server) = current_room.server_name() { + if server == services().globals.server_name() { + continue; + } - if !results.is_empty() { - // Early return so the client can see some data already - break; - } + if !results.is_empty() { + // Early return so the client can see some data already + break; + } - debug!("Asking {server} for /hierarchy"); - if let Ok(response) = services() - .sending - .send_federation_request( - server, - federation::space::get_hierarchy::v1::Request { - room_id: current_room.clone(), - suggested_only, - }, - ) - .await - { - debug!("Got response from {server} for /hierarchy\n{response:?}"); - let chunk = SpaceHierarchyRoomsChunk { - canonical_alias: response.room.canonical_alias, - name: response.room.name, - num_joined_members: response.room.num_joined_members, - room_id: response.room.room_id, - topic: response.room.topic, - world_readable: response.room.world_readable, - guest_can_join: response.room.guest_can_join, - avatar_url: response.room.avatar_url, - join_rule: response.room.join_rule.clone(), - room_type: response.room.room_type, - children_state: response.room.children_state, - }; - let children = response - .children - .iter() - .map(|c| c.room_id.clone()) - .collect::>(); + debug!("Asking {server} for 
/hierarchy"); + if let Ok(response) = services() + .sending + .send_federation_request( + server, + federation::space::get_hierarchy::v1::Request { + room_id: current_room.clone(), + suggested_only, + }, + ) + .await + { + debug!("Got response from {server} for /hierarchy\n{response:?}"); + let chunk = SpaceHierarchyRoomsChunk { + canonical_alias: response.room.canonical_alias, + name: response.room.name, + num_joined_members: response.room.num_joined_members, + room_id: response.room.room_id, + topic: response.room.topic, + world_readable: response.room.world_readable, + guest_can_join: response.room.guest_can_join, + avatar_url: response.room.avatar_url, + join_rule: response.room.join_rule.clone(), + room_type: response.room.room_type, + children_state: response.room.children_state, + }; + let children = response.children.iter().map(|c| c.room_id.clone()).collect::>(); - let join_rule = match response.room.join_rule { - SpaceRoomJoinRule::Invite => JoinRule::Invite, - SpaceRoomJoinRule::Knock => JoinRule::Knock, - SpaceRoomJoinRule::Private => JoinRule::Private, - SpaceRoomJoinRule::Restricted => { - JoinRule::Restricted(join_rules::Restricted { - allow: response - .room - .allowed_room_ids - .into_iter() - .map(AllowRule::room_membership) - .collect(), - }) - } - SpaceRoomJoinRule::KnockRestricted => { - JoinRule::KnockRestricted(join_rules::Restricted { - allow: response - .room - .allowed_room_ids - .into_iter() - .map(AllowRule::room_membership) - .collect(), - }) - } - SpaceRoomJoinRule::Public => JoinRule::Public, - _ => return Err(Error::BadServerResponse("Unknown join rule")), - }; - if self.handle_join_rule(&join_rule, sender_user, ¤t_room)? 
{ - if left_to_skip > 0 { - left_to_skip -= 1; - } else { - results.push(chunk.clone()); - } - if rooms_in_path.len() < max_depth { - stack.push(children.clone()); - } - } + let join_rule = match response.room.join_rule { + SpaceRoomJoinRule::Invite => JoinRule::Invite, + SpaceRoomJoinRule::Knock => JoinRule::Knock, + SpaceRoomJoinRule::Private => JoinRule::Private, + SpaceRoomJoinRule::Restricted => JoinRule::Restricted(join_rules::Restricted { + allow: response.room.allowed_room_ids.into_iter().map(AllowRule::room_membership).collect(), + }), + SpaceRoomJoinRule::KnockRestricted => JoinRule::KnockRestricted(join_rules::Restricted { + allow: response.room.allowed_room_ids.into_iter().map(AllowRule::room_membership).collect(), + }), + SpaceRoomJoinRule::Public => JoinRule::Public, + _ => return Err(Error::BadServerResponse("Unknown join rule")), + }; + if self.handle_join_rule(&join_rule, sender_user, ¤t_room)? { + if left_to_skip > 0 { + left_to_skip -= 1; + } else { + results.push(chunk.clone()); + } + if rooms_in_path.len() < max_depth { + stack.push(children.clone()); + } + } - self.roomid_spacechunk_cache.lock().unwrap().insert( - current_room.clone(), - Some(CachedSpaceChunk { - chunk, - children, - join_rule: CachedJoinRule::Full(join_rule), - }), - ); + self.roomid_spacechunk_cache.lock().unwrap().insert( + current_room.clone(), + Some(CachedSpaceChunk { + chunk, + children, + join_rule: CachedJoinRule::Full(join_rule), + }), + ); - /* TODO: - for child in response.children { - roomid_spacechunk_cache.insert( - current_room.clone(), - CachedSpaceChunk { - chunk: child.chunk, - children, - join_rule, - }, - ); - } - */ - } else { - self.roomid_spacechunk_cache - .lock() - .unwrap() - .insert(current_room.clone(), None); - } - } - } + /* TODO: + for child in response.children { + roomid_spacechunk_cache.insert( + current_room.clone(), + CachedSpaceChunk { + chunk: child.chunk, + children, + join_rule, + }, + ); + } + */ + } else { + 
self.roomid_spacechunk_cache.lock().unwrap().insert(current_room.clone(), None); + } + } + } - Ok(get_hierarchy::v1::Response { - next_batch: if results.is_empty() { - None - } else { - Some((skip + results.len()).to_string()) - }, - rooms: results, - }) - } + Ok(get_hierarchy::v1::Response { + next_batch: if results.is_empty() { + None + } else { + Some((skip + results.len()).to_string()) + }, + rooms: results, + }) + } - async fn get_room_chunk( - &self, - sender_user: &UserId, - room_id: &RoomId, - children: Vec>, - ) -> Result { - Ok(SpaceHierarchyRoomsChunk { - canonical_alias: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomCanonicalAliasEventContent| c.alias) - .map_err(|_| { - Error::bad_database("Invalid canonical alias event in database.") - }) - })?, - name: services().rooms.state_accessor.get_name(room_id)?, - num_joined_members: services() - .rooms - .state_cache - .room_joined_count(room_id)? - .unwrap_or_else(|| { - warn!("Room {} has no member count", room_id); - 0 - }) - .try_into() - .expect("user count should not be that big"), - room_id: room_id.to_owned(), - topic: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomTopic, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomTopicEventContent| Some(c.topic)) - .map_err(|_| { - error!("Invalid room topic event in database for room {}", room_id); - Error::bad_database("Invalid room topic event in database.") - }) - }) - .unwrap_or(None), - world_readable: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? 
- .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| { - c.history_visibility == HistoryVisibility::WorldReadable - }) - .map_err(|_| { - Error::bad_database( - "Invalid room history visibility event in database.", - ) - }) - })?, - guest_can_join: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? - .map_or(Ok(false), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomGuestAccessEventContent| { - c.guest_access == GuestAccess::CanJoin - }) - .map_err(|_| { - Error::bad_database("Invalid room guest access event in database.") - }) - })?, - avatar_url: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomAvatarEventContent| c.url) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }) - .transpose()? - // url is now an Option so we must flatten - .flatten(), - join_rule: { - let join_rule = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? - .map(|s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomJoinRulesEventContent| c.join_rule) - .map_err(|e| { - error!("Invalid room join rule event in database: {}", e); - Error::BadDatabase("Invalid room join rule event in database.") - }) - }) - .transpose()? - .unwrap_or(JoinRule::Invite); + async fn get_room_chunk( + &self, sender_user: &UserId, room_id: &RoomId, children: Vec>, + ) -> Result { + Ok(SpaceHierarchyRoomsChunk { + canonical_alias: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")? 
+ .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomCanonicalAliasEventContent| c.alias) + .map_err(|_| Error::bad_database("Invalid canonical alias event in database.")) + })?, + name: services().rooms.state_accessor.get_name(room_id)?, + num_joined_members: services() + .rooms + .state_cache + .room_joined_count(room_id)? + .unwrap_or_else(|| { + warn!("Room {} has no member count", room_id); + 0 + }) + .try_into() + .expect("user count should not be that big"), + room_id: room_id.to_owned(), + topic: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomTopic, "")? + .map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()).map(|c: RoomTopicEventContent| Some(c.topic)).map_err(|_| { + error!("Invalid room topic event in database for room {}", room_id); + Error::bad_database("Invalid room topic event in database.") + }) + }) + .unwrap_or(None), + world_readable: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? + .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| { + c.history_visibility == HistoryVisibility::WorldReadable + }) + .map_err(|_| Error::bad_database("Invalid room history visibility event in database.")) + })?, + guest_can_join: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomGuestAccess, "")? + .map_or(Ok(false), |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin) + .map_err(|_| Error::bad_database("Invalid room guest access event in database.")) + })?, + avatar_url: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomAvatar, "")? + .map(|s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomAvatarEventContent| c.url) + .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + }) + .transpose()? 
+ // url is now an Option so we must flatten + .flatten(), + join_rule: { + let join_rule = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomJoinRules, "")? + .map(|s| { + serde_json::from_str(s.content.get()).map(|c: RoomJoinRulesEventContent| c.join_rule).map_err( + |e| { + error!("Invalid room join rule event in database: {}", e); + Error::BadDatabase("Invalid room join rule event in database.") + }, + ) + }) + .transpose()? + .unwrap_or(JoinRule::Invite); - if !self.handle_join_rule(&join_rule, sender_user, room_id)? { - debug!("User is not allowed to see room {room_id}"); - // This error will be caught later - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "User is not allowed to see the room", - )); - } + if !self.handle_join_rule(&join_rule, sender_user, room_id)? { + debug!("User is not allowed to see room {room_id}"); + // This error will be caught later + return Err(Error::BadRequest(ErrorKind::Forbidden, "User is not allowed to see the room")); + } - self.translate_joinrule(&join_rule)? - }, - room_type: services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .map(|s| { - serde_json::from_str::(s.content.get()).map_err(|e| { - error!("Invalid room create event in database: {}", e); - Error::BadDatabase("Invalid room create event in database.") - }) - }) - .transpose()? - .and_then(|e| e.room_type), - children_state: children - .into_iter() - .map(|pdu| pdu.to_stripped_spacechild_state_event()) - .collect(), - }) - } + self.translate_joinrule(&join_rule)? + }, + room_type: services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .map(|s| { + serde_json::from_str::(s.content.get()).map_err(|e| { + error!("Invalid room create event in database: {}", e); + Error::BadDatabase("Invalid room create event in database.") + }) + }) + .transpose()? 
+ .and_then(|e| e.room_type), + children_state: children.into_iter().map(|pdu| pdu.to_stripped_spacechild_state_event()).collect(), + }) + } - fn translate_joinrule(&self, join_rule: &JoinRule) -> Result { - match join_rule { - JoinRule::Invite => Ok(SpaceRoomJoinRule::Invite), - JoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), - JoinRule::Private => Ok(SpaceRoomJoinRule::Private), - JoinRule::Restricted(_) => Ok(SpaceRoomJoinRule::Restricted), - JoinRule::KnockRestricted(_) => Ok(SpaceRoomJoinRule::KnockRestricted), - JoinRule::Public => Ok(SpaceRoomJoinRule::Public), - _ => Err(Error::BadServerResponse("Unknown join rule")), - } - } + fn translate_joinrule(&self, join_rule: &JoinRule) -> Result { + match join_rule { + JoinRule::Invite => Ok(SpaceRoomJoinRule::Invite), + JoinRule::Knock => Ok(SpaceRoomJoinRule::Knock), + JoinRule::Private => Ok(SpaceRoomJoinRule::Private), + JoinRule::Restricted(_) => Ok(SpaceRoomJoinRule::Restricted), + JoinRule::KnockRestricted(_) => Ok(SpaceRoomJoinRule::KnockRestricted), + JoinRule::Public => Ok(SpaceRoomJoinRule::Public), + _ => Err(Error::BadServerResponse("Unknown join rule")), + } + } - fn handle_simplified_join_rule( - &self, - join_rule: &SpaceRoomJoinRule, - sender_user: &UserId, - room_id: &RoomId, - ) -> Result { - let allowed = match join_rule { - SpaceRoomJoinRule::Public => true, - SpaceRoomJoinRule::Knock => true, - SpaceRoomJoinRule::Invite => services() - .rooms - .state_cache - .is_joined(sender_user, room_id)?, - _ => false, - }; + fn handle_simplified_join_rule( + &self, join_rule: &SpaceRoomJoinRule, sender_user: &UserId, room_id: &RoomId, + ) -> Result { + let allowed = match join_rule { + SpaceRoomJoinRule::Public => true, + SpaceRoomJoinRule::Knock => true, + SpaceRoomJoinRule::Invite => services().rooms.state_cache.is_joined(sender_user, room_id)?, + _ => false, + }; - Ok(allowed) - } + Ok(allowed) + } - fn handle_join_rule( - &self, - join_rule: &JoinRule, - sender_user: &UserId, - room_id: &RoomId, - ) 
-> Result { - if self.handle_simplified_join_rule( - &self.translate_joinrule(join_rule)?, - sender_user, - room_id, - )? { - return Ok(true); - } + fn handle_join_rule(&self, join_rule: &JoinRule, sender_user: &UserId, room_id: &RoomId) -> Result { + if self.handle_simplified_join_rule(&self.translate_joinrule(join_rule)?, sender_user, room_id)? { + return Ok(true); + } - match join_rule { - JoinRule::Restricted(r) => { - for rule in &r.allow { - if let join_rules::AllowRule::RoomMembership(rm) = rule { - if let Ok(true) = services() - .rooms - .state_cache - .is_joined(sender_user, &rm.room_id) - { - return Ok(true); - } - } - } + match join_rule { + JoinRule::Restricted(r) => { + for rule in &r.allow { + if let join_rules::AllowRule::RoomMembership(rm) = rule { + if let Ok(true) = services().rooms.state_cache.is_joined(sender_user, &rm.room_id) { + return Ok(true); + } + } + } - Ok(false) - } - JoinRule::KnockRestricted(_) => { - // TODO: Check rules - Ok(false) - } - _ => Ok(false), - } - } + Ok(false) + }, + JoinRule::KnockRestricted(_) => { + // TODO: Check rules + Ok(false) + }, + _ => Ok(false), + } + } } diff --git a/src/service/rooms/state/data.rs b/src/service/rooms/state/data.rs index 96116b02..b3850b40 100644 --- a/src/service/rooms/state/data.rs +++ b/src/service/rooms/state/data.rs @@ -1,31 +1,33 @@ -use crate::Result; -use ruma::{EventId, OwnedEventId, RoomId}; use std::{collections::HashSet, sync::Arc}; + +use ruma::{EventId, OwnedEventId, RoomId}; use tokio::sync::MutexGuard; +use crate::Result; + pub trait Data: Send + Sync { - /// Returns the last state hash key added to the db for the given room. - fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; + /// Returns the last state hash key added to the db for the given room. + fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result>; - /// Set the state hash to a new version, but does not update state_cache. 
- fn set_room_state( - &self, - room_id: &RoomId, - new_shortstatehash: u64, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()>; + /// Set the state hash to a new version, but does not update state_cache. + fn set_room_state( + &self, + room_id: &RoomId, + new_shortstatehash: u64, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()>; - /// Associates a state with an event. - fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; + /// Associates a state with an event. + fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()>; - /// Returns all events we would send as the prev_events of the next event. - fn get_forward_extremities(&self, room_id: &RoomId) -> Result>>; + /// Returns all events we would send as the prev_events of the next event. + fn get_forward_extremities(&self, room_id: &RoomId) -> Result>>; - /// Replace the forward extremities of the room. - fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()>; + /// Replace the forward extremities of the room. 
+ fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()>; } diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index a056a065..f7024fa9 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -1,430 +1,329 @@ mod data; use std::{ - collections::{HashMap, HashSet}, - sync::Arc, + collections::{HashMap, HashSet}, + sync::Arc, }; pub use data::Data; use ruma::{ - api::client::error::ErrorKind, - events::{ - room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, - AnyStrippedStateEvent, StateEventType, TimelineEventType, - }, - serde::Raw, - state_res::{self, StateMap}, - EventId, OwnedEventId, RoomId, RoomVersionId, UserId, + api::client::error::ErrorKind, + events::{ + room::{create::RoomCreateEventContent, member::RoomMemberEventContent}, + AnyStrippedStateEvent, StateEventType, TimelineEventType, + }, + serde::Raw, + state_res::{self, StateMap}, + EventId, OwnedEventId, RoomId, RoomVersionId, UserId, }; use tokio::sync::MutexGuard; use tracing::warn; +use super::state_compressor::CompressedStateEvent; use crate::{services, utils::calculate_hash, Error, PduEvent, Result}; -use super::state_compressor::CompressedStateEvent; - pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Set the room to the given statehash and update caches. - pub async fn force_state( - &self, - room_id: &RoomId, - shortstatehash: u64, - statediffnew: Arc>, - _statediffremoved: Arc>, - state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - for event_id in statediffnew.iter().filter_map(|new| { - services() - .rooms - .state_compressor - .parse_compressed_state_event(new) - .ok() - .map(|(_, id)| id) - }) { - let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? 
{ - Some(pdu) => pdu, - None => continue, - }; + /// Set the room to the given statehash and update caches. + pub async fn force_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + statediffnew: Arc>, + _statediffremoved: Arc>, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + for event_id in statediffnew.iter().filter_map(|new| { + services().rooms.state_compressor.parse_compressed_state_event(new).ok().map(|(_, id)| id) + }) { + let pdu = match services().rooms.timeline.get_pdu_json(&event_id)? { + Some(pdu) => pdu, + None => continue, + }; - let pdu: PduEvent = match serde_json::from_str( - &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), - ) { - Ok(pdu) => pdu, - Err(_) => continue, - }; + let pdu: PduEvent = match serde_json::from_str( + &serde_json::to_string(&pdu).expect("CanonicalJsonObj can be serialized to JSON"), + ) { + Ok(pdu) => pdu, + Err(_) => continue, + }; - match pdu.kind { - TimelineEventType::RoomMember => { - let membership_event = - match serde_json::from_str::(pdu.content.get()) { - Ok(e) => e, - Err(_) => continue, - }; + match pdu.kind { + TimelineEventType::RoomMember => { + let membership_event = match serde_json::from_str::(pdu.content.get()) { + Ok(e) => e, + Err(_) => continue, + }; - let state_key = match pdu.state_key { - Some(k) => k, - None => continue, - }; + let state_key = match pdu.state_key { + Some(k) => k, + None => continue, + }; - let user_id = match UserId::parse(state_key) { - Ok(id) => id, - Err(_) => continue, - }; + let user_id = match UserId::parse(state_key) { + Ok(id) => id, + Err(_) => continue, + }; - services() - .rooms - .state_cache - .update_membership( - room_id, - &user_id, - membership_event, - &pdu.sender, - None, - false, - ) - .await?; - } - TimelineEventType::SpaceChild => { - services() - .rooms - .spaces - .roomid_spacechunk_cache - .lock() - .unwrap() - .remove(&pdu.room_id); - } - _ => 
continue, - } - } + services() + .rooms + .state_cache + .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, false) + .await?; + }, + TimelineEventType::SpaceChild => { + services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id); + }, + _ => continue, + } + } - services().rooms.state_cache.update_joined_count(room_id)?; + services().rooms.state_cache.update_joined_count(room_id)?; - self.db - .set_room_state(room_id, shortstatehash, state_lock)?; + self.db.set_room_state(room_id, shortstatehash, state_lock)?; - Ok(()) - } + Ok(()) + } - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, state_ids_compressed))] - pub fn set_event_state( - &self, - event_id: &EventId, - room_id: &RoomId, - state_ids_compressed: Arc>, - ) -> Result { - let shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(event_id)?; + /// Generates a new StateHash and associates it with the incoming event. + /// + /// This adds all current state events (not including the incoming event) + /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
+ #[tracing::instrument(skip(self, state_ids_compressed))] + pub fn set_event_state( + &self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc>, + ) -> Result { + let shorteventid = services().rooms.short.get_or_create_shorteventid(event_id)?; - let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; + let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?; - let state_hash = calculate_hash( - &state_ids_compressed - .iter() - .map(|s| &s[..]) - .collect::>(), - ); + let state_hash = calculate_hash(&state_ids_compressed.iter().map(|s| &s[..]).collect::>()); - let (shortstatehash, already_existed) = services() - .rooms - .short - .get_or_create_shortstatehash(&state_hash)?; + let (shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; - if !already_existed { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, - )?; + if !already_existed { + let states_parents = previous_shortstatehash.map_or_else( + || Ok(Vec::new()), + |p| services().rooms.state_compressor.load_shortstatehash_info(p), + )?; - let (statediffnew, statediffremoved) = - if let Some(parent_stateinfo) = states_parents.last() { - let statediffnew: HashSet<_> = state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: HashSet<_> = state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&state_ids_compressed) - .copied() - .collect(); + let statediffremoved: HashSet<_> = + parent_stateinfo.1.difference(&state_ids_compressed).copied().collect(); - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (state_ids_compressed, Arc::new(HashSet::new())) - }; - 
services().rooms.state_compressor.save_state_from_diff( - shortstatehash, - statediffnew, - statediffremoved, - 1_000_000, // high number because no state will be based on this one - states_parents, - )?; - } + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (state_ids_compressed, Arc::new(HashSet::new())) + }; + services().rooms.state_compressor.save_state_from_diff( + shortstatehash, + statediffnew, + statediffremoved, + 1_000_000, // high number because no state will be based on this one + states_parents, + )?; + } - self.db.set_event_state(shorteventid, shortstatehash)?; + self.db.set_event_state(shorteventid, shortstatehash)?; - Ok(shortstatehash) - } + Ok(shortstatehash) + } - /// Generates a new StateHash and associates it with the incoming event. - /// - /// This adds all current state events (not including the incoming event) - /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. - #[tracing::instrument(skip(self, new_pdu))] - pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { - let shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(&new_pdu.event_id)?; + /// Generates a new StateHash and associates it with the incoming event. + /// + /// This adds all current state events (not including the incoming event) + /// to `stateid_pduid` and adds the incoming event to `eventid_statehash`. 
+ #[tracing::instrument(skip(self, new_pdu))] + pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result { + let shorteventid = services().rooms.short.get_or_create_shorteventid(&new_pdu.event_id)?; - let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; + let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?; - if let Some(p) = previous_shortstatehash { - self.db.set_event_state(shorteventid, p)?; - } + if let Some(p) = previous_shortstatehash { + self.db.set_event_state(shorteventid, p)?; + } - if let Some(state_key) = &new_pdu.state_key { - let states_parents = previous_shortstatehash.map_or_else( - || Ok(Vec::new()), - |p| { - services() - .rooms - .state_compressor - .load_shortstatehash_info(p) - }, - )?; + if let Some(state_key) = &new_pdu.state_key { + let states_parents = previous_shortstatehash.map_or_else( + || Ok(Vec::new()), + |p| services().rooms.state_compressor.load_shortstatehash_info(p), + )?; - let shortstatekey = services() - .rooms - .short - .get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; + let shortstatekey = + services().rooms.short.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?; - let new = services() - .rooms - .state_compressor - .compress_state_event(shortstatekey, &new_pdu.event_id)?; + let new = services().rooms.state_compressor.compress_state_event(shortstatekey, &new_pdu.event_id)?; - let replaces = states_parents - .last() - .map(|info| { - info.1 - .iter() - .find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())) - }) - .unwrap_or_default(); + let replaces = states_parents + .last() + .map(|info| info.1.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))) + .unwrap_or_default(); - if Some(&new) == replaces { - return Ok(previous_shortstatehash.expect("must exist")); - } + if Some(&new) == replaces { + return Ok(previous_shortstatehash.expect("must exist")); + } - // TODO: statehash with deterministic 
inputs - let shortstatehash = services().globals.next_count()?; + // TODO: statehash with deterministic inputs + let shortstatehash = services().globals.next_count()?; - let mut statediffnew = HashSet::new(); - statediffnew.insert(new); + let mut statediffnew = HashSet::new(); + statediffnew.insert(new); - let mut statediffremoved = HashSet::new(); - if let Some(replaces) = replaces { - statediffremoved.insert(*replaces); - } + let mut statediffremoved = HashSet::new(); + if let Some(replaces) = replaces { + statediffremoved.insert(*replaces); + } - services().rooms.state_compressor.save_state_from_diff( - shortstatehash, - Arc::new(statediffnew), - Arc::new(statediffremoved), - 2, - states_parents, - )?; + services().rooms.state_compressor.save_state_from_diff( + shortstatehash, + Arc::new(statediffnew), + Arc::new(statediffremoved), + 2, + states_parents, + )?; - Ok(shortstatehash) - } else { - Ok(previous_shortstatehash.expect("first event in room must be a state event")) - } - } + Ok(shortstatehash) + } else { + Ok(previous_shortstatehash.expect("first event in room must be a state event")) + } + } - #[tracing::instrument(skip(self, invite_event))] - pub fn calculate_invite_state( - &self, - invite_event: &PduEvent, - ) -> Result>> { - let mut state = Vec::new(); - // Add recommended events - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCreate, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomJoinRules, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomCanonicalAlias, - "", - )? 
{ - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomAvatar, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomName, - "", - )? { - state.push(e.to_stripped_state_event()); - } - if let Some(e) = services().rooms.state_accessor.room_state_get( - &invite_event.room_id, - &StateEventType::RoomMember, - invite_event.sender.as_str(), - )? { - state.push(e.to_stripped_state_event()); - } + #[tracing::instrument(skip(self, invite_event))] + pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result>> { + let mut state = Vec::new(); + // Add recommended events + if let Some(e) = + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomCanonicalAlias, + "", + )? { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = + services().rooms.state_accessor.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")? + { + state.push(e.to_stripped_state_event()); + } + if let Some(e) = services().rooms.state_accessor.room_state_get( + &invite_event.room_id, + &StateEventType::RoomMember, + invite_event.sender.as_str(), + )? 
{ + state.push(e.to_stripped_state_event()); + } - state.push(invite_event.to_stripped_state_event()); - Ok(state) - } + state.push(invite_event.to_stripped_state_event()); + Ok(state) + } - /// Set the state hash to a new version, but does not update state_cache. - #[tracing::instrument(skip(self))] - pub fn set_room_state( - &self, - room_id: &RoomId, - shortstatehash: u64, - mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db.set_room_state(room_id, shortstatehash, mutex_lock) - } + /// Set the state hash to a new version, but does not update state_cache. + #[tracing::instrument(skip(self))] + pub fn set_room_state( + &self, + room_id: &RoomId, + shortstatehash: u64, + mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.db.set_room_state(room_id, shortstatehash, mutex_lock) + } - /// Returns the room's version. - #[tracing::instrument(skip(self))] - pub fn get_room_version(&self, room_id: &RoomId) -> Result { - let create_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomCreate, - "", - )?; + /// Returns the room's version. + #[tracing::instrument(skip(self))] + pub fn get_room_version(&self, room_id: &RoomId) -> Result { + let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; - let create_event_content: RoomCreateEventContent = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; + let create_event_content: RoomCreateEventContent = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()? + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?; - Ok(create_event_content.room_version) - } + Ok(create_event_content.room_version) + } - pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { - self.db.get_room_shortstatehash(room_id) - } + pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result> { + self.db.get_room_shortstatehash(room_id) + } - pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { - self.db.get_forward_extremities(room_id) - } + pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result>> { + self.db.get_forward_extremities(room_id) + } - pub fn set_forward_extremities( - &self, - room_id: &RoomId, - event_ids: Vec, - state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<()> { - self.db - .set_forward_extremities(room_id, event_ids, state_lock) - } + pub fn set_forward_extremities( + &self, + room_id: &RoomId, + event_ids: Vec, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<()> { + self.db.set_forward_extremities(room_id, event_ids, state_lock) + } - /// This fetches auth events from the current state. - #[tracing::instrument(skip(self))] - pub fn get_auth_events( - &self, - room_id: &RoomId, - kind: &TimelineEventType, - sender: &UserId, - state_key: Option<&str>, - content: &serde_json::value::RawValue, - ) -> Result>> { - let shortstatehash = if let Some(current_shortstatehash) = - services().rooms.state.get_room_shortstatehash(room_id)? 
- { - current_shortstatehash - } else { - return Ok(HashMap::new()); - }; + /// This fetches auth events from the current state. + #[tracing::instrument(skip(self))] + pub fn get_auth_events( + &self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>, + content: &serde_json::value::RawValue, + ) -> Result>> { + let shortstatehash = + if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? { + current_shortstatehash + } else { + return Ok(HashMap::new()); + }; - let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content) - .expect("content is a valid JSON object"); + let auth_events = + state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object"); - let mut sauthevents = auth_events - .into_iter() - .filter_map(|(event_type, state_key)| { - services() - .rooms - .short - .get_shortstatekey(&event_type.to_string().into(), &state_key) - .ok() - .flatten() - .map(|s| (s, (event_type, state_key))) - }) - .collect::>(); + let mut sauthevents = auth_events + .into_iter() + .filter_map(|(event_type, state_key)| { + services() + .rooms + .short + .get_shortstatekey(&event_type.to_string().into(), &state_key) + .ok() + .flatten() + .map(|s| (s, (event_type, state_key))) + }) + .collect::>(); - let full_state = services() - .rooms - .state_compressor - .load_shortstatehash_info(shortstatehash)? - .pop() - .expect("there is always one layer") - .1; + let full_state = services() + .rooms + .state_compressor + .load_shortstatehash_info(shortstatehash)? 
+ .pop() + .expect("there is always one layer") + .1; - Ok(full_state - .iter() - .filter_map(|compressed| { - services() - .rooms - .state_compressor - .parse_compressed_state_event(compressed) - .ok() - }) - .filter_map(|(shortstatekey, event_id)| { - sauthevents.remove(&shortstatekey).map(|k| (k, event_id)) - }) - .filter_map(|(k, event_id)| { - services() - .rooms - .timeline - .get_pdu(&event_id) - .ok() - .flatten() - .map(|pdu| (k, pdu)) - }) - .collect()) - } + Ok(full_state + .iter() + .filter_map(|compressed| services().rooms.state_compressor.parse_compressed_state_event(compressed).ok()) + .filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id))) + .filter_map(|(k, event_id)| services().rooms.timeline.get_pdu(&event_id).ok().flatten().map(|pdu| (k, pdu))) + .collect()) + } } diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index f3ae3c21..8e046df3 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -7,53 +7,39 @@ use crate::{PduEvent, Result}; #[async_trait] pub trait Data: Send + Sync { - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. + async fn state_full_ids(&self, shortstatehash: u64) -> Result>>; - async fn state_full( - &self, - shortstatehash: u64, - ) -> Result>>; + async fn state_full(&self, shortstatehash: u64) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn state_get_id( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>>; + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). 
+ fn state_get_id( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn state_get( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>>; + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn state_get( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>>; - /// Returns the state hash for this pdu. - fn pdu_shortstatehash(&self, event_id: &EventId) -> Result>; + /// Returns the state hash for this pdu. + fn pdu_shortstatehash(&self, event_id: &EventId) -> Result>; - /// Returns the full room state. - async fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result>>; + /// Returns the full room state. + async fn room_state_full(&self, room_id: &RoomId) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn room_state_get_id( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>>; + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + fn room_state_get_id( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>>; - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - fn room_state_get( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>>; + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). 
+ fn room_state_get( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>>; } diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index f6e3376c..84fc3a28 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -1,316 +1,259 @@ mod data; use std::{ - collections::HashMap, - sync::{Arc, Mutex}, + collections::HashMap, + sync::{Arc, Mutex}, }; pub use data::Data; use lru_cache::LruCache; use ruma::{ - events::{ - room::{ - avatar::RoomAvatarEventContent, - history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, - member::{MembershipState, RoomMemberEventContent}, - name::RoomNameEventContent, - }, - StateEventType, - }, - EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + events::{ + room::{ + avatar::RoomAvatarEventContent, + history_visibility::{HistoryVisibility, RoomHistoryVisibilityEventContent}, + member::{MembershipState, RoomMemberEventContent}, + name::RoomNameEventContent, + }, + StateEventType, + }, + EventId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use tracing::error; use crate::{services, Error, PduEvent, Result}; pub struct Service { - pub db: &'static dyn Data, - pub server_visibility_cache: Mutex>, - pub user_visibility_cache: Mutex>, + pub db: &'static dyn Data, + pub server_visibility_cache: Mutex>, + pub user_visibility_cache: Mutex>, } impl Service { - /// Builds a StateMap by iterating over all keys that start - /// with state_hash, this gives the full state for the given state_hash. - #[tracing::instrument(skip(self))] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { - self.db.state_full_ids(shortstatehash).await - } + /// Builds a StateMap by iterating over all keys that start + /// with state_hash, this gives the full state for the given state_hash. 
+ #[tracing::instrument(skip(self))] + pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + self.db.state_full_ids(shortstatehash).await + } - pub async fn state_full( - &self, - shortstatehash: u64, - ) -> Result>> { - self.db.state_full(shortstatehash).await - } + pub async fn state_full(&self, shortstatehash: u64) -> Result>> { + self.db.state_full(shortstatehash).await + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn state_get_id( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - self.db.state_get_id(shortstatehash, event_type, state_key) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + #[tracing::instrument(skip(self))] + pub fn state_get_id( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.db.state_get_id(shortstatehash, event_type, state_key) + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - pub fn state_get( - &self, - shortstatehash: u64, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - self.db.state_get(shortstatehash, event_type, state_key) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + pub fn state_get( + &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.db.state_get(shortstatehash, event_type, state_key) + } - /// Get membership for given user in state - fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { - self.state_get( - shortstatehash, - &StateEventType::RoomMember, - user_id.as_str(), - )? 
- .map_or(Ok(MembershipState::Leave), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomMemberEventContent| c.membership) - .map_err(|_| Error::bad_database("Invalid room membership event in database.")) - }) - } + /// Get membership for given user in state + fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result { + self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?.map_or( + Ok(MembershipState::Leave), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomMemberEventContent| c.membership) + .map_err(|_| Error::bad_database("Invalid room membership event in database.")) + }, + ) + } - /// The user was a joined member at this state (potentially in the past) - fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join) // Return sensible default, i.e. false - } + /// The user was a joined member at this state (potentially in the past) + fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id).is_ok_and(|s| s == MembershipState::Join) + // Return sensible default, i.e. + // false + } - /// The user was an invited or joined room member at this state (potentially - /// in the past) - fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { - self.user_membership(shortstatehash, user_id) - .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) - // Return sensible default, i.e. false - } + /// The user was an invited or joined room member at this state (potentially + /// in the past) + fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + self.user_membership(shortstatehash, user_id) + .is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite) + // Return sensible default, i.e. 
false + } - /// Whether a server is allowed to see an event through federation, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip(self, origin, room_id, event_id))] - pub fn server_can_see_event( - &self, - origin: &ServerName, - room_id: &RoomId, - event_id: &EventId, - ) -> Result { - let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { - return Ok(true); - }; + /// Whether a server is allowed to see an event through federation, based on + /// the room's history_visibility at that event's state. + #[tracing::instrument(skip(self, origin, room_id, event_id))] + pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result { + let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else { + return Ok(true); + }; - if let Some(visibility) = self - .server_visibility_cache - .lock() - .unwrap() - .get_mut(&(origin.to_owned(), shortstatehash)) - { - return Ok(*visibility); - } + if let Some(visibility) = + self.server_visibility_cache.lock().unwrap().get_mut(&(origin.to_owned(), shortstatehash)) + { + return Ok(*visibility); + } - let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
- .map_or(Ok(HistoryVisibility::Shared), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") - }) - })?; + let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( + Ok(HistoryVisibility::Shared), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map_err(|_| Error::bad_database("Invalid history visibility event in database.")) + }, + )?; - let mut current_server_members = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|member| member.server_name() == origin); + let mut current_server_members = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|member| member.server_name() == origin); - let visibility = match history_visibility { - HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, - HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny - current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) - } - HistoryVisibility::Joined => { - // Allow if any member on requested server was joined, else deny - current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) - } - _ => { - error!("Unknown history visibility {history_visibility}"); - false - } - }; + let visibility = match history_visibility { + HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true, + HistoryVisibility::Invited => { + // Allow if any member on requesting server was AT LEAST invited, else deny + current_server_members.any(|member| self.user_was_invited(shortstatehash, &member)) + }, + HistoryVisibility::Joined => { + // Allow if any member on requested server was joined, else deny 
+ current_server_members.any(|member| self.user_was_joined(shortstatehash, &member)) + }, + _ => { + error!("Unknown history visibility {history_visibility}"); + false + }, + }; - self.server_visibility_cache - .lock() - .unwrap() - .insert((origin.to_owned(), shortstatehash), visibility); + self.server_visibility_cache.lock().unwrap().insert((origin.to_owned(), shortstatehash), visibility); - Ok(visibility) - } + Ok(visibility) + } - /// Whether a user is allowed to see an event, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip(self, user_id, room_id, event_id))] - pub fn user_can_see_event( - &self, - user_id: &UserId, - room_id: &RoomId, - event_id: &EventId, - ) -> Result { - let shortstatehash = match self.pdu_shortstatehash(event_id)? { - Some(shortstatehash) => shortstatehash, - None => return Ok(true), - }; + /// Whether a user is allowed to see an event, based on + /// the room's history_visibility at that event's state. + #[tracing::instrument(skip(self, user_id, room_id, event_id))] + pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result { + let shortstatehash = match self.pdu_shortstatehash(event_id)? { + Some(shortstatehash) => shortstatehash, + None => return Ok(true), + }; - if let Some(visibility) = self - .user_visibility_cache - .lock() - .unwrap() - .get_mut(&(user_id.to_owned(), shortstatehash)) - { - return Ok(*visibility); - } + if let Some(visibility) = + self.user_visibility_cache.lock().unwrap().get_mut(&(user_id.to_owned(), shortstatehash)) + { + return Ok(*visibility); + } - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; - let history_visibility = self - .state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")? 
- .map_or(Ok(HistoryVisibility::Shared), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") - }) - })?; + let history_visibility = self.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?.map_or( + Ok(HistoryVisibility::Shared), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map_err(|_| Error::bad_database("Invalid history visibility event in database.")) + }, + )?; - let visibility = match history_visibility { - HistoryVisibility::WorldReadable => true, - HistoryVisibility::Shared => currently_member, - HistoryVisibility::Invited => { - // Allow if any member on requesting server was AT LEAST invited, else deny - self.user_was_invited(shortstatehash, user_id) - } - HistoryVisibility::Joined => { - // Allow if any member on requested server was joined, else deny - self.user_was_joined(shortstatehash, user_id) - } - _ => { - error!("Unknown history visibility {history_visibility}"); - false - } - }; + let visibility = match history_visibility { + HistoryVisibility::WorldReadable => true, + HistoryVisibility::Shared => currently_member, + HistoryVisibility::Invited => { + // Allow if any member on requesting server was AT LEAST invited, else deny + self.user_was_invited(shortstatehash, user_id) + }, + HistoryVisibility::Joined => { + // Allow if any member on requested server was joined, else deny + self.user_was_joined(shortstatehash, user_id) + }, + _ => { + error!("Unknown history visibility {history_visibility}"); + false + }, + }; - self.user_visibility_cache - .lock() - .unwrap() - .insert((user_id.to_owned(), shortstatehash), visibility); + self.user_visibility_cache.lock().unwrap().insert((user_id.to_owned(), shortstatehash), visibility); - Ok(visibility) - } + Ok(visibility) + } - /// Whether a user is allowed to see an 
event, based on - /// the room's history_visibility at that event's state. - #[tracing::instrument(skip(self, user_id, room_id))] - pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { - let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; + /// Whether a user is allowed to see an event, based on + /// the room's history_visibility at that event's state. + #[tracing::instrument(skip(self, user_id, room_id))] + pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result { + let currently_member = services().rooms.state_cache.is_joined(user_id, room_id)?; - let history_visibility = self - .room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")? - .map_or(Ok(HistoryVisibility::Shared), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) - .map_err(|_| { - Error::bad_database("Invalid history visibility event in database.") - }) - })?; + let history_visibility = self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?.map_or( + Ok(HistoryVisibility::Shared), + |s| { + serde_json::from_str(s.content.get()) + .map(|c: RoomHistoryVisibilityEventContent| c.history_visibility) + .map_err(|_| Error::bad_database("Invalid history visibility event in database.")) + }, + )?; - Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) - } + Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable) + } - /// Returns the state hash for this pdu. - pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { - self.db.pdu_shortstatehash(event_id) - } + /// Returns the state hash for this pdu. + pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result> { self.db.pdu_shortstatehash(event_id) } - /// Returns the full room state. 
- #[tracing::instrument(skip(self))] - pub async fn room_state_full( - &self, - room_id: &RoomId, - ) -> Result>> { - self.db.room_state_full(room_id).await - } + /// Returns the full room state. + #[tracing::instrument(skip(self))] + pub async fn room_state_full(&self, room_id: &RoomId) -> Result>> { + self.db.room_state_full(room_id).await + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get_id( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - self.db.room_state_get_id(room_id, event_type, state_key) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get_id( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.db.room_state_get_id(room_id, event_type, state_key) + } - /// Returns a single PDU from `room_id` with key (`event_type`, `state_key`). - #[tracing::instrument(skip(self))] - pub fn room_state_get( - &self, - room_id: &RoomId, - event_type: &StateEventType, - state_key: &str, - ) -> Result>> { - self.db.room_state_get(room_id, event_type, state_key) - } + /// Returns a single PDU from `room_id` with key (`event_type`, + /// `state_key`). + #[tracing::instrument(skip(self))] + pub fn room_state_get( + &self, room_id: &RoomId, event_type: &StateEventType, state_key: &str, + ) -> Result>> { + self.db.room_state_get(room_id, event_type, state_key) + } - pub fn get_name(&self, room_id: &RoomId) -> Result> { - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomName, "")? - .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map(|c: RoomNameEventContent| Some(c.name)) - .map_err(|e| { - error!( - "Invalid room name event in database for room {}. 
{}", - room_id, e - ); - Error::bad_database("Invalid room name event in database.") - }) - }) - } + pub fn get_name(&self, room_id: &RoomId) -> Result> { + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomName, "")?.map_or(Ok(None), |s| { + serde_json::from_str(s.content.get()).map(|c: RoomNameEventContent| Some(c.name)).map_err(|e| { + error!("Invalid room name event in database for room {}. {}", room_id, e); + Error::bad_database("Invalid room name event in database.") + }) + }) + } - pub fn get_avatar(&self, room_id: &RoomId) -> Result> { - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomAvatar, "")? - .map_or(Ok(ruma::JsOption::Undefined), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) - }) - } + pub fn get_avatar(&self, room_id: &RoomId) -> Result> { + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomAvatar, "")?.map_or( + Ok(ruma::JsOption::Undefined), + |s| { + serde_json::from_str(s.content.get()) + .map_err(|_| Error::bad_database("Invalid room avatar event in database.")) + }, + ) + } - pub fn get_member( - &self, - room_id: &RoomId, - user_id: &UserId, - ) -> Result> { - services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())? 
- .map_or(Ok(None), |s| { - serde_json::from_str(s.content.get()) - .map_err(|_| Error::bad_database("Invalid room member event in database.")) - }) - } + pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?.map_or( + Ok(None), + |s| { + serde_json::from_str(s.content.get()) + .map_err(|_| Error::bad_database("Invalid room member event in database.")) + }, + ) + } } diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index e36ac03e..44ae57f8 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -1,112 +1,79 @@ use std::{collections::HashSet, sync::Arc}; -use crate::Result; use ruma::{ - api::appservice::Registration, - events::{AnyStrippedStateEvent, AnySyncStateEvent}, - serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + api::appservice::Registration, + events::{AnyStrippedStateEvent, AnySyncStateEvent}, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -type StrippedStateEventIter<'a> = - Box>)>> + 'a>; +use crate::Result; -type AnySyncStateEventIter<'a> = - Box>)>> + 'a>; +type StrippedStateEventIter<'a> = Box>)>> + 'a>; + +type AnySyncStateEventIter<'a> = Box>)>> + 'a>; pub trait Data: Send + Sync { - fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn mark_as_invited( - &self, - user_id: &UserId, - room_id: &RoomId, - last_state: Option>>, - ) -> Result<()>; - fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn mark_as_invited( + &self, user_id: &UserId, room_id: &RoomId, last_state: 
Option>>, + ) -> Result<()>; + fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; + fn update_joined_count(&self, room_id: &RoomId) -> Result<()>; - fn get_our_real_users(&self, room_id: &RoomId) -> Result>>; + fn get_our_real_users(&self, room_id: &RoomId) -> Result>>; - fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, Registration), - ) -> Result; + fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result; - /// Makes a user forget a room. - fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; + /// Makes a user forget a room. + fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()>; - /// Returns an iterator of all servers participating in this room. - fn room_servers<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a>; + /// Returns an iterator of all servers participating in this room. + fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result; + fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result; - /// Returns an iterator of all rooms a server participates in (as far as we know). - fn server_rooms<'a>( - &'a self, - server: &ServerName, - ) -> Box> + 'a>; + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). + fn server_rooms<'a>(&'a self, server: &ServerName) -> Box> + 'a>; - /// Returns an iterator over all joined members of a room. - fn room_members<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a>; + /// Returns an iterator over all joined members of a room. 
+ fn room_members<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - fn room_joined_count(&self, room_id: &RoomId) -> Result>; + fn room_joined_count(&self, room_id: &RoomId) -> Result>; - fn room_invited_count(&self, room_id: &RoomId) -> Result>; + fn room_invited_count(&self, room_id: &RoomId) -> Result>; - /// Returns an iterator over all User IDs who ever joined a room. - fn room_useroncejoined<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a>; + /// Returns an iterator over all User IDs who ever joined a room. + fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - /// Returns an iterator over all invited members of a room. - fn room_members_invited<'a>( - &'a self, - room_id: &RoomId, - ) -> Box> + 'a>; + /// Returns an iterator over all invited members of a room. + fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; - fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; + fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result>; - /// Returns an iterator over all rooms this user joined. - fn rooms_joined<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a>; + /// Returns an iterator over all rooms this user joined. + fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box> + 'a>; - /// Returns an iterator over all rooms a user was invited to. - fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a>; + /// Returns an iterator over all rooms a user was invited to. 
+ fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a>; - fn invite_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>>; + fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>>; - fn left_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>>; + fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>>; - /// Returns an iterator over all rooms a user left. - fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a>; + /// Returns an iterator over all rooms a user left. + fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a>; - fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; - fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result; - fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; - fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 8e12f817..7f305712 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,393 +1,311 @@ use std::{collections::HashSet, sync::Arc}; -use ruma::api::federation; +pub use data::Data; use ruma::{ - api::appservice::Registration, - events::{ - direct::DirectEvent, - ignored_user_list::IgnoredUserListEvent, - room::{ - create::RoomCreateEventContent, - member::{MembershipState, RoomMemberEventContent}, - }, - AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, - RoomAccountDataEventType, StateEventType, - }, - serde::Raw, - OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, + 
api::{appservice::Registration, federation}, + events::{ + direct::DirectEvent, + ignored_user_list::IgnoredUserListEvent, + room::{ + create::RoomCreateEventContent, + member::{MembershipState, RoomMemberEventContent}, + }, + AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, + }, + serde::Raw, + OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; use tracing::warn; -pub use data::Data; - use crate::{services, Error, Result}; mod data; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Update current membership data. - #[tracing::instrument(skip(self, last_state))] - pub async fn update_membership( - &self, - room_id: &RoomId, - user_id: &UserId, - membership_event: RoomMemberEventContent, - sender: &UserId, - last_state: Option>>, - update_joined_count: bool, - ) -> Result<()> { - let membership = membership_event.membership; + /// Update current membership data. + #[tracing::instrument(skip(self, last_state))] + pub async fn update_membership( + &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, + last_state: Option>>, update_joined_count: bool, + ) -> Result<()> { + let membership = membership_event.membership; - // Keep track what remote users exist by adding them as "deactivated" users - if user_id.server_name() != services().globals.server_name() { - if !services().users.exists(user_id)? { - services().users.create(user_id, None)?; - } + // Keep track what remote users exist by adding them as "deactivated" users + if user_id.server_name() != services().globals.server_name() { + if !services().users.exists(user_id)? { + services().users.create(user_id, None)?; + } - // Try to update our local copy of the user if ours does not match - if ((services().users.displayname(user_id)? != membership_event.displayname) - || (services().users.avatar_url(user_id)? 
!= membership_event.avatar_url) - || (services().users.blurhash(user_id)? != membership_event.blurhash)) - && (membership != MembershipState::Leave) - { - let response = services() - .sending - .send_federation_request( - user_id.server_name(), - federation::query::get_profile_information::v1::Request { - user_id: user_id.into(), - field: None, // we want the full user's profile to update locally too - }, - ) - .await?; + // Try to update our local copy of the user if ours does not match + if ((services().users.displayname(user_id)? != membership_event.displayname) + || (services().users.avatar_url(user_id)? != membership_event.avatar_url) + || (services().users.blurhash(user_id)? != membership_event.blurhash)) + && (membership != MembershipState::Leave) + { + let response = services() + .sending + .send_federation_request( + user_id.server_name(), + federation::query::get_profile_information::v1::Request { + user_id: user_id.into(), + field: None, // we want the full user's profile to update locally too + }, + ) + .await?; - services() - .users - .set_displayname(user_id, response.displayname.clone()) - .await?; - services() - .users - .set_avatar_url(user_id, response.avatar_url) - .await?; - services() - .users - .set_blurhash(user_id, response.blurhash) - .await?; - }; - } + services().users.set_displayname(user_id, response.displayname.clone()).await?; + services().users.set_avatar_url(user_id, response.avatar_url).await?; + services().users.set_blurhash(user_id, response.blurhash).await?; + }; + } - match &membership { - MembershipState::Join => { - // Check if the user never joined this room - if !self.once_joined(user_id, room_id)? { - // Add the user ID to the join list then - self.db.mark_as_once_joined(user_id, room_id)?; + match &membership { + MembershipState::Join => { + // Check if the user never joined this room + if !self.once_joined(user_id, room_id)? 
{ + // Add the user ID to the join list then + self.db.mark_as_once_joined(user_id, room_id)?; - // Check if the room has a predecessor - if let Some(predecessor) = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomCreate, "")? - .and_then(|create| serde_json::from_str(create.content.get()).ok()) - .and_then(|content: RoomCreateEventContent| content.predecessor) - { - // Copy user settings from predecessor to the current room: - // - Push rules - // - // TODO: finish this once push rules are implemented. - // - // let mut push_rules_event_content: PushRulesEvent = account_data - // .get( - // None, - // user_id, - // EventType::PushRules, - // )?; - // - // NOTE: find where `predecessor.room_id` match - // and update to `room_id`. - // - // account_data - // .update( - // None, - // user_id, - // EventType::PushRules, - // &push_rules_event_content, - // globals, - // ) - // .ok(); + // Check if the room has a predecessor + if let Some(predecessor) = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomCreate, "")? + .and_then(|create| serde_json::from_str(create.content.get()).ok()) + .and_then(|content: RoomCreateEventContent| content.predecessor) + { + // Copy user settings from predecessor to the current room: + // - Push rules + // + // TODO: finish this once push rules are implemented. + // + // let mut push_rules_event_content: PushRulesEvent = account_data + // .get( + // None, + // user_id, + // EventType::PushRules, + // )?; + // + // NOTE: find where `predecessor.room_id` match + // and update to `room_id`. + // + // account_data + // .update( + // None, + // user_id, + // EventType::PushRules, + // &push_rules_event_content, + // globals, + // ) + // .ok(); - // Copy old tags to new room - if let Some(tag_event) = services() - .account_data - .get( - Some(&predecessor.room_id), - user_id, - RoomAccountDataEventType::Tag, - )? 
- .map(|event| { - serde_json::from_str(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) - }) - { - services() - .account_data - .update( - Some(room_id), - user_id, - RoomAccountDataEventType::Tag, - &tag_event?, - ) - .ok(); - }; + // Copy old tags to new room + if let Some(tag_event) = services() + .account_data + .get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)? + .map(|event| { + serde_json::from_str(event.get()).map_err(|e| { + warn!("Invalid account data event in db: {e:?}"); + Error::BadDatabase("Invalid account data event in db.") + }) + }) { + services() + .account_data + .update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?) + .ok(); + }; - // Copy direct chat flag - if let Some(direct_event) = services() - .account_data - .get( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - )? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) - }) - { - let mut direct_event = direct_event?; - let mut room_ids_updated = false; + // Copy direct chat flag + if let Some(direct_event) = services() + .account_data + .get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())? 
+ .map(|event| { + serde_json::from_str::(event.get()).map_err(|e| { + warn!("Invalid account data event in db: {e:?}"); + Error::BadDatabase("Invalid account data event in db.") + }) + }) { + let mut direct_event = direct_event?; + let mut room_ids_updated = false; - for room_ids in direct_event.content.0.values_mut() { - if room_ids.iter().any(|r| r == &predecessor.room_id) { - room_ids.push(room_id.to_owned()); - room_ids_updated = true; - } - } + for room_ids in direct_event.content.0.values_mut() { + if room_ids.iter().any(|r| r == &predecessor.room_id) { + room_ids.push(room_id.to_owned()); + room_ids_updated = true; + } + } - if room_ids_updated { - services().account_data.update( - None, - user_id, - GlobalAccountDataEventType::Direct.to_string().into(), - &serde_json::to_value(&direct_event) - .expect("to json always works"), - )?; - } - }; - } - } + if room_ids_updated { + services().account_data.update( + None, + user_id, + GlobalAccountDataEventType::Direct.to_string().into(), + &serde_json::to_value(&direct_event).expect("to json always works"), + )?; + } + }; + } + } - self.db.mark_as_joined(user_id, room_id)?; - } - MembershipState::Invite => { - // We want to know if the sender is ignored by the receiver - let is_ignored = services() - .account_data - .get( - None, // Ignored users are in global account data - user_id, // Receiver - GlobalAccountDataEventType::IgnoredUserList - .to_string() - .into(), - )? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid account data event in db: {e:?}"); - Error::BadDatabase("Invalid account data event in db.") - }) - }) - .transpose()? 
- .map_or(false, |ignored| { - ignored - .content - .ignored_users - .iter() - .any(|(user, _details)| user == sender) - }); + self.db.mark_as_joined(user_id, room_id)?; + }, + MembershipState::Invite => { + // We want to know if the sender is ignored by the receiver + let is_ignored = services() + .account_data + .get( + None, // Ignored users are in global account data + user_id, // Receiver + GlobalAccountDataEventType::IgnoredUserList.to_string().into(), + )? + .map(|event| { + serde_json::from_str::(event.get()).map_err(|e| { + warn!("Invalid account data event in db: {e:?}"); + Error::BadDatabase("Invalid account data event in db.") + }) + }) + .transpose()? + .map_or(false, |ignored| { + ignored.content.ignored_users.iter().any(|(user, _details)| user == sender) + }); - if is_ignored { - return Ok(()); - } + if is_ignored { + return Ok(()); + } - self.db.mark_as_invited(user_id, room_id, last_state)?; - } - MembershipState::Leave | MembershipState::Ban => { - self.db.mark_as_left(user_id, room_id)?; - } - _ => {} - } + self.db.mark_as_invited(user_id, room_id, last_state)?; + }, + MembershipState::Leave | MembershipState::Ban => { + self.db.mark_as_left(user_id, room_id)?; + }, + _ => {}, + } - if update_joined_count { - self.update_joined_count(room_id)?; - } + if update_joined_count { + self.update_joined_count(room_id)?; + } - Ok(()) - } + Ok(()) + } - #[tracing::instrument(skip(self, room_id))] - pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { - self.db.update_joined_count(room_id) - } + #[tracing::instrument(skip(self, room_id))] + pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) } - #[tracing::instrument(skip(self, room_id))] - pub fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { - self.db.get_our_real_users(room_id) - } + #[tracing::instrument(skip(self, room_id))] + pub fn get_our_real_users(&self, room_id: &RoomId) -> Result>> { + 
self.db.get_our_real_users(room_id) + } - #[tracing::instrument(skip(self, room_id, appservice))] - pub fn appservice_in_room( - &self, - room_id: &RoomId, - appservice: &(String, Registration), - ) -> Result { - self.db.appservice_in_room(room_id, appservice) - } + #[tracing::instrument(skip(self, room_id, appservice))] + pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result { + self.db.appservice_in_room(room_id, appservice) + } - /// Makes a user forget a room. - #[tracing::instrument(skip(self))] - pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { - self.db.forget(room_id, user_id) - } + /// Makes a user forget a room. + #[tracing::instrument(skip(self))] + pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) } - /// Returns an iterator of all servers participating in this room. - #[tracing::instrument(skip(self))] - pub fn room_servers<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator> + 'a { - self.db.room_servers(room_id) - } + /// Returns an iterator of all servers participating in this room. + #[tracing::instrument(skip(self))] + pub fn room_servers<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + self.db.room_servers(room_id) + } - #[tracing::instrument(skip(self))] - pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { - self.db.server_in_room(server, room_id) - } + #[tracing::instrument(skip(self))] + pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result { + self.db.server_in_room(server, room_id) + } - /// Returns an iterator of all rooms a server participates in (as far as we know). - #[tracing::instrument(skip(self))] - pub fn server_rooms<'a>( - &'a self, - server: &ServerName, - ) -> impl Iterator> + 'a { - self.db.server_rooms(server) - } + /// Returns an iterator of all rooms a server participates in (as far as we + /// know). 
+ #[tracing::instrument(skip(self))] + pub fn server_rooms<'a>(&'a self, server: &ServerName) -> impl Iterator> + 'a { + self.db.server_rooms(server) + } - /// Returns an iterator over all joined members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator> + 'a { - self.db.room_members(room_id) - } + /// Returns an iterator over all joined members of a room. + #[tracing::instrument(skip(self))] + pub fn room_members<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + self.db.room_members(room_id) + } - #[tracing::instrument(skip(self))] - pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { - self.db.room_joined_count(room_id) - } + #[tracing::instrument(skip(self))] + pub fn room_joined_count(&self, room_id: &RoomId) -> Result> { self.db.room_joined_count(room_id) } - #[tracing::instrument(skip(self))] - pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { - self.db.room_invited_count(room_id) - } + #[tracing::instrument(skip(self))] + pub fn room_invited_count(&self, room_id: &RoomId) -> Result> { self.db.room_invited_count(room_id) } - /// Returns an iterator over all User IDs who ever joined a room. - #[tracing::instrument(skip(self))] - pub fn room_useroncejoined<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator> + 'a { - self.db.room_useroncejoined(room_id) - } + /// Returns an iterator over all User IDs who ever joined a room. + #[tracing::instrument(skip(self))] + pub fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + self.db.room_useroncejoined(room_id) + } - /// Returns an iterator over all invited members of a room. - #[tracing::instrument(skip(self))] - pub fn room_members_invited<'a>( - &'a self, - room_id: &RoomId, - ) -> impl Iterator> + 'a { - self.db.room_members_invited(room_id) - } + /// Returns an iterator over all invited members of a room. 
+ #[tracing::instrument(skip(self))] + pub fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> impl Iterator> + 'a { + self.db.room_members_invited(room_id) + } - #[tracing::instrument(skip(self))] - pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_invite_count(room_id, user_id) - } + #[tracing::instrument(skip(self))] + pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.get_invite_count(room_id, user_id) + } - #[tracing::instrument(skip(self))] - pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { - self.db.get_left_count(room_id, user_id) - } + #[tracing::instrument(skip(self))] + pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + self.db.get_left_count(room_id, user_id) + } - /// Returns an iterator over all rooms this user joined. - #[tracing::instrument(skip(self))] - pub fn rooms_joined<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator> + 'a { - self.db.rooms_joined(user_id) - } + /// Returns an iterator over all rooms this user joined. + #[tracing::instrument(skip(self))] + pub fn rooms_joined<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + self.db.rooms_joined(user_id) + } - /// Returns an iterator over all rooms a user was invited to. - #[tracing::instrument(skip(self))] - pub fn rooms_invited<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator>)>> + 'a { - self.db.rooms_invited(user_id) - } + /// Returns an iterator over all rooms a user was invited to. 
+ #[tracing::instrument(skip(self))] + pub fn rooms_invited<'a>( + &'a self, user_id: &UserId, + ) -> impl Iterator>)>> + 'a { + self.db.rooms_invited(user_id) + } - #[tracing::instrument(skip(self))] - pub fn invite_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>> { - self.db.invite_state(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + self.db.invite_state(user_id, room_id) + } - #[tracing::instrument(skip(self))] - pub fn left_state( - &self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result>>> { - self.db.left_state(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result>>> { + self.db.left_state(user_id, room_id) + } - /// Returns an iterator over all rooms a user left. - #[tracing::instrument(skip(self))] - pub fn rooms_left<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator>)>> + 'a { - self.db.rooms_left(user_id) - } + /// Returns an iterator over all rooms a user left. 
+ #[tracing::instrument(skip(self))] + pub fn rooms_left<'a>( + &'a self, user_id: &UserId, + ) -> impl Iterator>)>> + 'a { + self.db.rooms_left(user_id) + } - #[tracing::instrument(skip(self))] - pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.once_joined(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.once_joined(user_id, room_id) + } - #[tracing::instrument(skip(self))] - pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_joined(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_joined(user_id, room_id) } - #[tracing::instrument(skip(self))] - pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_invited(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.is_invited(user_id, room_id) + } - #[tracing::instrument(skip(self))] - pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.is_left(user_id, room_id) - } + #[tracing::instrument(skip(self))] + pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } } diff --git a/src/service/rooms/state_compressor/data.rs b/src/service/rooms/state_compressor/data.rs index d221d576..eddc8716 100644 --- a/src/service/rooms/state_compressor/data.rs +++ b/src/service/rooms/state_compressor/data.rs @@ -4,12 +4,12 @@ use super::CompressedStateEvent; use crate::Result; pub struct StateDiff { - pub parent: Option, - pub added: Arc>, - pub removed: Arc>, + pub parent: Option, + pub added: Arc>, + pub removed: Arc>, } pub trait Data: Send + Sync { - fn get_statediff(&self, shortstatehash: u64) -> Result; - fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; + 
fn get_statediff(&self, shortstatehash: u64) -> Result; + fn save_statediff(&self, shortstatehash: u64, diff: StateDiff) -> Result<()>; } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index 338249f6..2570d325 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -1,325 +1,276 @@ pub mod data; use std::{ - collections::HashSet, - mem::size_of, - sync::{Arc, Mutex}, + collections::HashSet, + mem::size_of, + sync::{Arc, Mutex}, }; pub use data::Data; use lru_cache::LruCache; use ruma::{EventId, RoomId}; +use self::data::StateDiff; use crate::{services, utils, Result}; -use self::data::StateDiff; - type StateInfoLruCache = Mutex< - LruCache< - u64, - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, - >, + LruCache< + u64, + Vec<( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed + )>, + >, >; type ShortStateInfoResult = Result< - Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed - )>, + Vec<( + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed + )>, >; type ParentStatesVec = Vec<( - u64, // sstatehash - Arc>, // full state - Arc>, // added - Arc>, // removed + u64, // sstatehash + Arc>, // full state + Arc>, // added + Arc>, // removed )>; -type HashSetCompressStateEvent = Result<( - u64, - Arc>, - Arc>, -)>; +type HashSetCompressStateEvent = Result<(u64, Arc>, Arc>)>; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, - pub stateinfo_cache: StateInfoLruCache, + pub stateinfo_cache: StateInfoLruCache, } pub type CompressedStateEvent = [u8; 2 * size_of::()]; impl Service { - /// Returns a stack with info on shortstatehash, full state, added diff and removed diff for the selected shortstatehash and each parent layer. 
- #[tracing::instrument(skip(self))] - pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { - if let Some(r) = self - .stateinfo_cache - .lock() - .unwrap() - .get_mut(&shortstatehash) - { - return Ok(r.clone()); - } + /// Returns a stack with info on shortstatehash, full state, added diff and + /// removed diff for the selected shortstatehash and each parent layer. + #[tracing::instrument(skip(self))] + pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult { + if let Some(r) = self.stateinfo_cache.lock().unwrap().get_mut(&shortstatehash) { + return Ok(r.clone()); + } - let StateDiff { - parent, - added, - removed, - } = self.db.get_statediff(shortstatehash)?; + let StateDiff { + parent, + added, + removed, + } = self.db.get_statediff(shortstatehash)?; - if let Some(parent) = parent { - let mut response = self.load_shortstatehash_info(parent)?; - let mut state = (*response.last().unwrap().1).clone(); - state.extend(added.iter().copied()); - let removed = (*removed).clone(); - for r in &removed { - state.remove(r); - } + if let Some(parent) = parent { + let mut response = self.load_shortstatehash_info(parent)?; + let mut state = (*response.last().unwrap().1).clone(); + state.extend(added.iter().copied()); + let removed = (*removed).clone(); + for r in &removed { + state.remove(r); + } - response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); + response.push((shortstatehash, Arc::new(state), added, Arc::new(removed))); - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); + self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); - Ok(response) - } else { - let response = vec![(shortstatehash, added.clone(), added, removed)]; - self.stateinfo_cache - .lock() - .unwrap() - .insert(shortstatehash, response.clone()); - Ok(response) - } - } + Ok(response) + } else { + let response = vec![(shortstatehash, added.clone(), added, 
removed)]; + self.stateinfo_cache.lock().unwrap().insert(shortstatehash, response.clone()); + Ok(response) + } + } - pub fn compress_state_event( - &self, - shortstatekey: u64, - event_id: &EventId, - ) -> Result { - let mut v = shortstatekey.to_be_bytes().to_vec(); - v.extend_from_slice( - &services() - .rooms - .short - .get_or_create_shorteventid(event_id)? - .to_be_bytes(), - ); - Ok(v.try_into().expect("we checked the size above")) - } + pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result { + let mut v = shortstatekey.to_be_bytes().to_vec(); + v.extend_from_slice(&services().rooms.short.get_or_create_shorteventid(event_id)?.to_be_bytes()); + Ok(v.try_into().expect("we checked the size above")) + } - /// Returns shortstatekey, event id - pub fn parse_compressed_state_event( - &self, - compressed_event: &CompressedStateEvent, - ) -> Result<(u64, Arc)> { - Ok(( - utils::u64_from_bytes(&compressed_event[0..size_of::()]) - .expect("bytes have right length"), - services().rooms.short.get_eventid_from_short( - utils::u64_from_bytes(&compressed_event[size_of::()..]) - .expect("bytes have right length"), - )?, - )) - } + /// Returns shortstatekey, event id + pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc)> { + Ok(( + utils::u64_from_bytes(&compressed_event[0..size_of::()]).expect("bytes have right length"), + services().rooms.short.get_eventid_from_short( + utils::u64_from_bytes(&compressed_event[size_of::()..]).expect("bytes have right length"), + )?, + )) + } - /// Creates a new shortstatehash that often is just a diff to an already existing - /// shortstatehash and therefore very efficient. - /// - /// There are multiple layers of diffs. The bottom layer 0 always contains the full state. Layer - /// 1 contains diffs to states of layer 0, layer 2 diffs to layer 1 and so on. 
If layer n > 0 - /// grows too big, it will be combined with layer n-1 to create a new diff on layer n-1 that's - /// based on layer n-2. If that layer is also too big, it will recursively fix above layers too. - /// - /// * `shortstatehash` - Shortstatehash of this state - /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid - /// * `statediffremoved` - Removed from base. Each vec is shortstatekey+shorteventid - /// * `diff_to_sibling` - Approximately how much the diff grows each time for this layer - /// * `parent_states` - A stack with info on shortstatehash, full state, added diff and removed diff for each parent layer - #[tracing::instrument(skip( - self, - statediffnew, - statediffremoved, - diff_to_sibling, - parent_states - ))] - pub fn save_state_from_diff( - &self, - shortstatehash: u64, - statediffnew: Arc>, - statediffremoved: Arc>, - diff_to_sibling: usize, - mut parent_states: ParentStatesVec, - ) -> Result<()> { - let diffsum = statediffnew.len() + statediffremoved.len(); + /// Creates a new shortstatehash that often is just a diff to an already + /// existing shortstatehash and therefore very efficient. + /// + /// There are multiple layers of diffs. The bottom layer 0 always contains + /// the full state. Layer 1 contains diffs to states of layer 0, layer 2 + /// diffs to layer 1 and so on. If layer n > 0 grows too big, it will be + /// combined with layer n-1 to create a new diff on layer n-1 that's + /// based on layer n-2. If that layer is also too big, it will recursively + /// fix above layers too. + /// + /// * `shortstatehash` - Shortstatehash of this state + /// * `statediffnew` - Added to base. Each vec is shortstatekey+shorteventid + /// * `statediffremoved` - Removed from base. 
Each vec is + /// shortstatekey+shorteventid + /// * `diff_to_sibling` - Approximately how much the diff grows each time + /// for this layer + /// * `parent_states` - A stack with info on shortstatehash, full state, + /// added diff and removed diff for each parent layer + #[tracing::instrument(skip(self, statediffnew, statediffremoved, diff_to_sibling, parent_states))] + pub fn save_state_from_diff( + &self, shortstatehash: u64, statediffnew: Arc>, + statediffremoved: Arc>, diff_to_sibling: usize, + mut parent_states: ParentStatesVec, + ) -> Result<()> { + let diffsum = statediffnew.len() + statediffremoved.len(); - if parent_states.len() > 3 { - // Number of layers - // To many layers, we have to go deeper - let parent = parent_states.pop().unwrap(); + if parent_states.len() > 3 { + // Number of layers + // To many layers, we have to go deeper + let parent = parent_states.pop().unwrap(); - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + let mut parent_new = (*parent.2).clone(); + let mut parent_removed = (*parent.3).clone(); - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); - } - // Else it was added in the parent and we removed it again. We can forget this change - } + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); + } + // Else it was added in the parent and we removed it again. We + // can forget this change + } - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); - } - // Else it was removed in the parent and we added it again. 
We can forget this change - } + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. We + // can forget this change + } - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; - return Ok(()); - } + return Ok(()); + } - if parent_states.is_empty() { - // There is no parent layer, create a new state - self.db.save_statediff( - shortstatehash, - StateDiff { - parent: None, - added: statediffnew, - removed: statediffremoved, - }, - )?; + if parent_states.is_empty() { + // There is no parent layer, create a new state + self.db.save_statediff( + shortstatehash, + StateDiff { + parent: None, + added: statediffnew, + removed: statediffremoved, + }, + )?; - return Ok(()); - }; + return Ok(()); + }; - // Else we have two options. - // 1. We add the current diff on top of the parent layer. - // 2. We replace a layer above + // Else we have two options. + // 1. We add the current diff on top of the parent layer. + // 2. 
We replace a layer above - let parent = parent_states.pop().unwrap(); - let parent_diff = parent.2.len() + parent.3.len(); + let parent = parent_states.pop().unwrap(); + let parent_diff = parent.2.len() + parent.3.len(); - if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { - // Diff too big, we replace above layer(s) - let mut parent_new = (*parent.2).clone(); - let mut parent_removed = (*parent.3).clone(); + if diffsum * diffsum >= 2 * diff_to_sibling * parent_diff { + // Diff too big, we replace above layer(s) + let mut parent_new = (*parent.2).clone(); + let mut parent_removed = (*parent.3).clone(); - for removed in statediffremoved.iter() { - if !parent_new.remove(removed) { - // It was not added in the parent and we removed it - parent_removed.insert(*removed); - } - // Else it was added in the parent and we removed it again. We can forget this change - } + for removed in statediffremoved.iter() { + if !parent_new.remove(removed) { + // It was not added in the parent and we removed it + parent_removed.insert(*removed); + } + // Else it was added in the parent and we removed it again. We + // can forget this change + } - for new in statediffnew.iter() { - if !parent_removed.remove(new) { - // It was not touched in the parent and we added it - parent_new.insert(*new); - } - // Else it was removed in the parent and we added it again. We can forget this change - } + for new in statediffnew.iter() { + if !parent_removed.remove(new) { + // It was not touched in the parent and we added it + parent_new.insert(*new); + } + // Else it was removed in the parent and we added it again. 
We + // can forget this change + } - self.save_state_from_diff( - shortstatehash, - Arc::new(parent_new), - Arc::new(parent_removed), - diffsum, - parent_states, - )?; - } else { - // Diff small enough, we add diff as layer on top of parent - self.db.save_statediff( - shortstatehash, - StateDiff { - parent: Some(parent.0), - added: statediffnew, - removed: statediffremoved, - }, - )?; - } + self.save_state_from_diff( + shortstatehash, + Arc::new(parent_new), + Arc::new(parent_removed), + diffsum, + parent_states, + )?; + } else { + // Diff small enough, we add diff as layer on top of parent + self.db.save_statediff( + shortstatehash, + StateDiff { + parent: Some(parent.0), + added: statediffnew, + removed: statediffremoved, + }, + )?; + } - Ok(()) - } + Ok(()) + } - /// Returns the new shortstatehash, and the state diff from the previous room state - pub fn save_state( - &self, - room_id: &RoomId, - new_state_ids_compressed: Arc>, - ) -> HashSetCompressStateEvent { - let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; + /// Returns the new shortstatehash, and the state diff from the previous + /// room state + pub fn save_state( + &self, room_id: &RoomId, new_state_ids_compressed: Arc>, + ) -> HashSetCompressStateEvent { + let previous_shortstatehash = services().rooms.state.get_room_shortstatehash(room_id)?; - let state_hash = utils::calculate_hash( - &new_state_ids_compressed - .iter() - .map(|bytes| &bytes[..]) - .collect::>(), - ); + let state_hash = + utils::calculate_hash(&new_state_ids_compressed.iter().map(|bytes| &bytes[..]).collect::>()); - let (new_shortstatehash, already_existed) = services() - .rooms - .short - .get_or_create_shortstatehash(&state_hash)?; + let (new_shortstatehash, already_existed) = services().rooms.short.get_or_create_shortstatehash(&state_hash)?; - if Some(new_shortstatehash) == previous_shortstatehash { - return Ok(( - new_shortstatehash, - Arc::new(HashSet::new()), - Arc::new(HashSet::new()), - 
)); - } + if Some(new_shortstatehash) == previous_shortstatehash { + return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new()))); + } - let states_parents = previous_shortstatehash - .map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; + let states_parents = + previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?; - let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() - { - let statediffnew: HashSet<_> = new_state_ids_compressed - .difference(&parent_stateinfo.1) - .copied() - .collect(); + let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() { + let statediffnew: HashSet<_> = new_state_ids_compressed.difference(&parent_stateinfo.1).copied().collect(); - let statediffremoved: HashSet<_> = parent_stateinfo - .1 - .difference(&new_state_ids_compressed) - .copied() - .collect(); + let statediffremoved: HashSet<_> = + parent_stateinfo.1.difference(&new_state_ids_compressed).copied().collect(); - (Arc::new(statediffnew), Arc::new(statediffremoved)) - } else { - (new_state_ids_compressed, Arc::new(HashSet::new())) - }; + (Arc::new(statediffnew), Arc::new(statediffremoved)) + } else { + (new_state_ids_compressed, Arc::new(HashSet::new())) + }; - if !already_existed { - self.save_state_from_diff( - new_shortstatehash, - statediffnew.clone(), - statediffremoved.clone(), - 2, // every state change is 2 event changes on average - states_parents, - )?; - }; + if !already_existed { + self.save_state_from_diff( + new_shortstatehash, + statediffnew.clone(), + statediffremoved.clone(), + 2, // every state change is 2 event changes on average + states_parents, + )?; + }; - Ok((new_shortstatehash, statediffnew, statediffremoved)) - } + Ok((new_shortstatehash, statediffnew, statediffremoved)) + } } diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index 2f062e23..b18f4b79 100644 --- 
a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,17 +1,14 @@ -use crate::{PduEvent, Result}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; +use crate::{PduEvent, Result}; + type PduEventIterResult<'a> = Result> + 'a>>; pub trait Data: Send + Sync { - fn threads_until<'a>( - &'a self, - user_id: &'a UserId, - room_id: &'a RoomId, - until: u64, - include: &'a IncludeThreads, - ) -> PduEventIterResult<'a>; + fn threads_until<'a>( + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, + ) -> PduEventIterResult<'a>; - fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; - fn get_participants(&self, root_id: &[u8]) -> Result>>; + fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()>; + fn get_participants(&self, root_id: &[u8]) -> Result>>; } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index ad70f2e0..fd9e8b93 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -4,115 +4,92 @@ use std::collections::BTreeMap; pub use data::Data; use ruma::{ - api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, - events::relation::BundledThread, - uint, CanonicalJsonValue, EventId, RoomId, UserId, + api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads}, + events::relation::BundledThread, + uint, CanonicalJsonValue, EventId, RoomId, UserId, }; - use serde_json::json; use crate::{services, Error, PduEvent, Result}; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn threads_until<'a>( - &'a self, - user_id: &'a UserId, - room_id: &'a RoomId, - until: u64, - include: &'a IncludeThreads, - ) -> Result> + 'a> { - self.db.threads_until(user_id, room_id, until, include) - } + pub fn threads_until<'a>( + &'a self, user_id: &'a UserId, 
room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, + ) -> Result> + 'a> { + self.db.threads_until(user_id, room_id, until, include) + } - pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { - let root_id = &services() - .rooms - .timeline - .get_pdu_id(root_event_id)? - .ok_or_else(|| { - Error::BadRequest( - ErrorKind::InvalidParam, - "Invalid event id in thread message", - ) - })?; + pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> { + let root_id = &services() + .rooms + .timeline + .get_pdu_id(root_event_id)? + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?; - let root_pdu = services() - .rooms - .timeline - .get_pdu_from_id(root_id)? - .ok_or_else(|| { - Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found") - })?; + let root_pdu = services() + .rooms + .timeline + .get_pdu_from_id(root_id)? + .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; - let mut root_pdu_json = services() - .rooms - .timeline - .get_pdu_json_from_id(root_id)? - .ok_or_else(|| { - Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found") - })?; + let mut root_pdu_json = services() + .rooms + .timeline + .get_pdu_json_from_id(root_id)? 
+ .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?; - if let CanonicalJsonValue::Object(unsigned) = root_pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) - { - if let Some(mut relations) = unsigned - .get("m.relations") - .and_then(|r| r.as_object()) - .and_then(|r| r.get("m.thread")) - .and_then(|relations| { - serde_json::from_value::(relations.clone().into()).ok() - }) - { - // Thread already existed - relations.count += uint!(1); - relations.latest_event = pdu.to_message_like_event(); + if let CanonicalJsonValue::Object(unsigned) = root_pdu_json + .entry("unsigned".to_owned()) + .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) + { + if let Some(mut relations) = unsigned + .get("m.relations") + .and_then(|r| r.as_object()) + .and_then(|r| r.get("m.thread")) + .and_then(|relations| serde_json::from_value::(relations.clone().into()).ok()) + { + // Thread already existed + relations.count += uint!(1); + relations.latest_event = pdu.to_message_like_event(); - let content = serde_json::to_value(relations).expect("to_value always works"); + let content = serde_json::to_value(relations).expect("to_value always works"); - unsigned.insert( - "m.relations".to_owned(), - json!({ "m.thread": content }) - .try_into() - .expect("thread is valid json"), - ); - } else { - // New thread - let relations = BundledThread { - latest_event: pdu.to_message_like_event(), - count: uint!(1), - current_user_participated: true, - }; + unsigned.insert( + "m.relations".to_owned(), + json!({ "m.thread": content }).try_into().expect("thread is valid json"), + ); + } else { + // New thread + let relations = BundledThread { + latest_event: pdu.to_message_like_event(), + count: uint!(1), + current_user_participated: true, + }; - let content = serde_json::to_value(relations).expect("to_value always works"); + let content = serde_json::to_value(relations).expect("to_value always 
works"); - unsigned.insert( - "m.relations".to_owned(), - json!({ "m.thread": content }) - .try_into() - .expect("thread is valid json"), - ); - } + unsigned.insert( + "m.relations".to_owned(), + json!({ "m.thread": content }).try_into().expect("thread is valid json"), + ); + } - services() - .rooms - .timeline - .replace_pdu(root_id, &root_pdu_json, &root_pdu)?; - } + services().rooms.timeline.replace_pdu(root_id, &root_pdu_json, &root_pdu)?; + } - let mut users = Vec::new(); - if let Some(userids) = self.db.get_participants(root_id)? { - users.extend_from_slice(&userids); - users.push(pdu.sender.clone()); - } else { - users.push(root_pdu.sender); - users.push(pdu.sender.clone()); - } + let mut users = Vec::new(); + if let Some(userids) = self.db.get_participants(root_id)? { + users.extend_from_slice(&userids); + users.push(pdu.sender.clone()); + } else { + users.push(root_pdu.sender); + users.push(pdu.sender.clone()); + } - self.db.update_participants(root_id, &users) - } + self.db.update_participants(root_id, &users) + } } diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 6290b8cc..a036b455 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -2,92 +2,67 @@ use std::sync::Arc; use ruma::{CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId}; +use super::PduCount; use crate::{PduEvent, Result}; -use super::PduCount; - pub trait Data: Send + Sync { - fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; + fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result; - /// Returns the `count` of this pdu's id. - fn get_pdu_count(&self, event_id: &EventId) -> Result>; + /// Returns the `count` of this pdu's id. + fn get_pdu_count(&self, event_id: &EventId) -> Result>; - /// Returns the json of a pdu. - fn get_pdu_json(&self, event_id: &EventId) -> Result>; + /// Returns the json of a pdu. 
+ fn get_pdu_json(&self, event_id: &EventId) -> Result>; - /// Returns the json of a pdu. - fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result>; + /// Returns the json of a pdu. + fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result>; - /// Returns the pdu's id. - fn get_pdu_id(&self, event_id: &EventId) -> Result>>; + /// Returns the pdu's id. + fn get_pdu_id(&self, event_id: &EventId) -> Result>>; - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result>; - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - fn get_pdu(&self, event_id: &EventId) -> Result>>; + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + fn get_pdu(&self, event_id: &EventId) -> Result>>; - /// Returns the pdu. - /// - /// This does __NOT__ check the outliers `Tree`. - fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result>; - /// Returns the pdu as a `BTreeMap`. - fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; + /// Returns the pdu as a `BTreeMap`. 
+ fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result>; - /// Adds a new pdu to the timeline - fn append_pdu( - &self, - pdu_id: &[u8], - pdu: &PduEvent, - json: &CanonicalJsonObject, - count: u64, - ) -> Result<()>; + /// Adds a new pdu to the timeline + fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()>; - // Adds a new pdu to the backfilled timeline - fn prepend_backfill_pdu( - &self, - pdu_id: &[u8], - event_id: &EventId, - json: &CanonicalJsonObject, - ) -> Result<()>; + // Adds a new pdu to the backfilled timeline + fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()>; - /// Removes a pdu and creates a new one with the same id. - fn replace_pdu( - &self, - pdu_id: &[u8], - pdu_json: &CanonicalJsonObject, - pdu: &PduEvent, - ) -> Result<()>; + /// Removes a pdu and creates a new one with the same id. + fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()>; - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. - #[allow(clippy::type_complexity)] - fn pdus_until<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - until: PduCount, - ) -> Result> + 'a>>; + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. + #[allow(clippy::type_complexity)] + fn pdus_until<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, + ) -> Result> + 'a>>; - /// Returns an iterator over all events in a room that happened after the event with id `from` - /// in chronological order. 
- #[allow(clippy::type_complexity)] - fn pdus_after<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - from: PduCount, - ) -> Result> + 'a>>; + /// Returns an iterator over all events in a room that happened after the + /// event with id `from` in chronological order. + #[allow(clippy::type_complexity)] + fn pdus_after<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, + ) -> Result> + 'a>>; - fn increment_notification_counts( - &self, - room_id: &RoomId, - notifies: Vec, - highlights: Vec, - ) -> Result<()>; + fn increment_notification_counts( + &self, room_id: &RoomId, notifies: Vec, highlights: Vec, + ) -> Result<()>; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index eb29d643..0a913e14 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,1308 +1,1066 @@ pub(crate) mod data; use std::{ - cmp::Ordering, - collections::{BTreeMap, HashMap}, -}; - -use std::{ - collections::HashSet, - sync::{Arc, Mutex, RwLock}, + cmp::Ordering, + collections::{BTreeMap, HashMap, HashSet}, + sync::{Arc, Mutex, RwLock}, }; pub use data::Data; use regex::Regex; use ruma::{ - api::{client::error::ErrorKind, federation}, - canonical_json::to_canonical_value, - events::{ - push_rules::PushRulesEvent, - room::{ - create::RoomCreateEventContent, - encrypted::Relation, - member::{MembershipState, RoomMemberEventContent}, - power_levels::RoomPowerLevelsEventContent, - }, - GlobalAccountDataEventType, StateEventType, TimelineEventType, - }, - push::{Action, Ruleset, Tweak}, - serde::Base64, - state_res, - state_res::{Event, RoomVersion}, - uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, - OwnedServerName, RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, + api::{client::error::ErrorKind, federation}, + canonical_json::to_canonical_value, + events::{ + push_rules::PushRulesEvent, + room::{ + create::RoomCreateEventContent, + 
encrypted::Relation, + member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, + }, + GlobalAccountDataEventType, StateEventType, TimelineEventType, + }, + push::{Action, Ruleset, Tweak}, + serde::Base64, + state_res, + state_res::{Event, RoomVersion}, + uint, user_id, CanonicalJsonObject, CanonicalJsonValue, EventId, OwnedEventId, OwnedRoomId, OwnedServerName, + RoomAliasId, RoomId, RoomVersionId, ServerName, UserId, }; use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use tokio::sync::MutexGuard; use tracing::{error, info, warn}; -use crate::{ - api::server_server, - service::pdu::{EventHash, PduBuilder}, - services, utils, Error, PduEvent, Result, -}; - use super::state_compressor::CompressedStateEvent; +use crate::{ + api::server_server, + service::pdu::{EventHash, PduBuilder}, + services, utils, Error, PduEvent, Result, +}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] pub enum PduCount { - Backfilled(u64), - Normal(u64), + Backfilled(u64), + Normal(u64), } impl PduCount { - pub fn min() -> Self { - Self::Backfilled(u64::MAX) - } - pub fn max() -> Self { - Self::Normal(u64::MAX) - } + pub fn min() -> Self { Self::Backfilled(u64::MAX) } - pub fn try_from_string(token: &str) -> Result { - if let Some(stripped_token) = token.strip_prefix('-') { - stripped_token.parse().map(PduCount::Backfilled) - } else { - token.parse().map(PduCount::Normal) - } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) - } + pub fn max() -> Self { Self::Normal(u64::MAX) } - pub fn stringify(&self) -> String { - match self { - PduCount::Backfilled(x) => format!("-{x}"), - PduCount::Normal(x) => x.to_string(), - } - } + pub fn try_from_string(token: &str) -> Result { + if let Some(stripped_token) = token.strip_prefix('-') { + stripped_token.parse().map(PduCount::Backfilled) + } else { + token.parse().map(PduCount::Normal) + } + .map_err(|_| 
Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) + } + + pub fn stringify(&self) -> String { + match self { + PduCount::Backfilled(x) => format!("-{x}"), + PduCount::Normal(x) => x.to_string(), + } + } } impl PartialOrd for PduCount { - fn partial_cmp(&self, other: &Self) -> Option { - Some(self.cmp(other)) - } + fn partial_cmp(&self, other: &Self) -> Option { Some(self.cmp(other)) } } impl Ord for PduCount { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), - (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), - (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, - (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, - } - } + fn cmp(&self, other: &Self) -> Ordering { + match (self, other) { + (PduCount::Normal(s), PduCount::Normal(o)) => s.cmp(o), + (PduCount::Backfilled(s), PduCount::Backfilled(o)) => o.cmp(s), + (PduCount::Normal(_), PduCount::Backfilled(_)) => Ordering::Greater, + (PduCount::Backfilled(_), PduCount::Normal(_)) => Ordering::Less, + } + } } pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, - pub lasttimelinecount_cache: Mutex>, + pub lasttimelinecount_cache: Mutex>, } impl Service { - #[tracing::instrument(skip(self))] - pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { - self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? - .next() - .map(|o| o.map(|(_, p)| Arc::new(p))) - .transpose() - } - - #[tracing::instrument(skip(self))] - pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { - self.db.last_timeline_count(sender_user, room_id) - } - - /// Returns the `count` of this pdu's id. - pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_count(event_id) - } - - // TODO Is this the same as the function above? 
- /* - #[tracing::instrument(skip(self))] - pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { - let prefix = self - .get_shortroomid(room_id)? - .expect("room exists") - .to_be_bytes() - .to_vec(); - - let mut last_possible_key = prefix.clone(); - last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.pduid_pdu - .iter_from(&last_possible_key, true) - .take_while(move |(k, _)| k.starts_with(&prefix)) - .next() - .map(|b| self.pdu_count(&b.0)) - .transpose() - .map(|op| op.unwrap_or_default()) - } - */ - - /// Returns the version of a room, if known - pub fn get_room_version(&self, room_id: &RoomId) -> Result> { - let create_event = services().rooms.state_accessor.room_state_get( - room_id, - &StateEventType::RoomCreate, - "", - )?; - - let create_event_content: Option = create_event - .as_ref() - .map(|create_event| { - serde_json::from_str(create_event.content.get()).map_err(|e| { - warn!("Invalid create event: {}", e); - Error::bad_database("Invalid create event in db.") - }) - }) - .transpose()?; - - Ok(create_event_content.map(|content| content.room_version)) - } - - /// Returns the json of a pdu. - pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_json(event_id) - } - - /// Returns the json of a pdu. - pub fn get_non_outlier_pdu_json( - &self, - event_id: &EventId, - ) -> Result> { - self.db.get_non_outlier_pdu_json(event_id) - } - - /// Returns the pdu's id. - pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { - self.db.get_pdu_id(event_id) - } - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { - self.db.get_non_outlier_pdu(event_id) - } - - /// Returns the pdu. - /// - /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. - pub fn get_pdu(&self, event_id: &EventId) -> Result>> { - self.db.get_pdu(event_id) - } - - /// Returns the pdu. 
- /// - /// This does __NOT__ check the outliers `Tree`. - pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { - self.db.get_pdu_from_id(pdu_id) - } - - /// Returns the pdu as a `BTreeMap`. - pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { - self.db.get_pdu_json_from_id(pdu_id) - } - - /// Removes a pdu and creates a new one with the same id. - #[tracing::instrument(skip(self))] - pub fn replace_pdu( - &self, - pdu_id: &[u8], - pdu_json: &CanonicalJsonObject, - pdu: &PduEvent, - ) -> Result<()> { - self.db.replace_pdu(pdu_id, pdu_json, pdu) - } - - /// Creates a new persisted data unit and adds it to a room. - /// - /// By this point the incoming event should be fully authenticated, no auth happens - /// in `append_pdu`. - /// - /// Returns pdu id - #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] - pub async fn append_pdu( - &self, - pdu: &PduEvent, - mut pdu_json: CanonicalJsonObject, - leaves: Vec, - state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { - let shortroomid = services() - .rooms - .short - .get_shortroomid(&pdu.room_id)? - .expect("room exists"); - - // Make unsigned fields correct. 
This is not properly documented in the spec, but state - // events need to have previous content in the unsigned field, so clients can easily - // interpret things like membership changes - if let Some(state_key) = &pdu.state_key { - if let CanonicalJsonValue::Object(unsigned) = pdu_json - .entry("unsigned".to_owned()) - .or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) - { - if let Some(shortstatehash) = services() - .rooms - .state_accessor - .pdu_shortstatehash(&pdu.event_id) - .unwrap() - { - if let Some(prev_state) = services() - .rooms - .state_accessor - .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) - .unwrap() - { - unsigned.insert( - "prev_content".to_owned(), - CanonicalJsonValue::Object( - utils::to_canonical_object(prev_state.content.clone()).map_err( - |e| { - error!( - "Failed to convert prev_state to canonical JSON: {}", - e - ); - Error::bad_database( - "Failed to convert prev_state to canonical JSON.", - ) - }, - )?, - ), - ); - } - } - } else { - error!("Invalid unsigned type in pdu."); - } - } - - // We must keep track of all events that have been referenced. 
- services() - .rooms - .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services() - .rooms - .state - .set_forward_extremities(&pdu.room_id, leaves, state_lock)?; - - let mutex_insert = Arc::clone( - services() - .globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(pdu.room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().await; - - let count1 = services().globals.next_count()?; - // Mark as read first so the sending client doesn't get a notification even if appending - // fails - services() - .rooms - .edus - .read_receipt - .private_read_set(&pdu.room_id, &pdu.sender, count1)?; - services() - .rooms - .user - .reset_notification_counts(&pdu.sender, &pdu.room_id)?; - - let count2 = services().globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); - - // https://spec.matrix.org/v1.9/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property - // For backwards-compatibility with older clients, - // servers should add a redacts property to the top level of m.room.redaction events in when serving such events over the Client-Server API. - if pdu.kind == TimelineEventType::RoomRedaction - && services().rooms.state.get_room_version(&pdu.room_id)? 
== RoomVersionId::V11 - { - #[derive(Deserialize)] - struct Redaction { - redacts: Option, - } - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; - - if let Some(redact_id) = &content.redacts { - pdu_json.insert( - "redacts".to_owned(), - CanonicalJsonValue::String(redact_id.to_string()), - ); - } - } - - // Insert pdu - self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; - - drop(insert_lock); - - // See if the event matches any known pushers - let power_levels: RoomPowerLevelsEventContent = services() - .rooms - .state_accessor - .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? - .unwrap_or_default(); - - let sync_pdu = pdu.to_sync_room_event(); - - let mut notifies = Vec::new(); - let mut highlights = Vec::new(); - - let mut push_target = services() - .rooms - .state_cache - .get_our_real_users(&pdu.room_id)?; - - if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key) = &pdu.state_key { - let target_user_id = UserId::parse(state_key.clone()) - .expect("This state_key was previously validated"); - - if !push_target.contains(&target_user_id) { - let mut target = push_target.as_ref().clone(); - target.insert(target_user_id); - push_target = Arc::new(target); - } - } - } - - for user in push_target.iter() { - // Don't notify the user of their own events - if user == &pdu.sender { - continue; - } - - let rules_for_user = services() - .account_data - .get( - None, - user, - GlobalAccountDataEventType::PushRules.to_string().into(), - )? - .map(|event| { - serde_json::from_str::(event.get()).map_err(|e| { - warn!("Invalid push rules event in db for user ID {user}: {e}"); - Error::bad_database("Invalid push rules event in db.") - }) - }) - .transpose()? 
- .map_or_else( - || Ruleset::server_default(user), - |ev: PushRulesEvent| ev.content.global, - ); - - let mut highlight = false; - let mut notify = false; - - for action in services().pusher.get_actions( - user, - &rules_for_user, - &power_levels, - &sync_pdu, - &pdu.room_id, - )? { - match action { - Action::Notify => notify = true, - Action::SetTweak(Tweak::Highlight(true)) => { - highlight = true; - } - _ => {} - }; - } - - if notify { - notifies.push(user.clone()); - } - - if highlight { - highlights.push(user.clone()); - } - - for push_key in services().pusher.get_pushkeys(user) { - services().sending.send_push_pdu(&pdu_id, user, push_key?)?; - } - } - - self.db - .increment_notification_counts(&pdu.room_id, notifies, highlights)?; - - match pdu.kind { - TimelineEventType::RoomRedaction => { - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; - match room_version_id { - RoomVersionId::V1 - | RoomVersionId::V2 - | RoomVersionId::V3 - | RoomVersionId::V4 - | RoomVersionId::V5 - | RoomVersionId::V6 - | RoomVersionId::V7 - | RoomVersionId::V8 - | RoomVersionId::V9 - | RoomVersionId::V10 => { - if let Some(redact_id) = &pdu.redacts { - self.redact_pdu(redact_id, pdu)?; - } - } - RoomVersionId::V11 => { - #[derive(Deserialize)] - struct Redaction { - redacts: Option, - } - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|e| { - warn!("Invalid content in redaction pdu: {e}"); - Error::bad_database("Invalid content in redaction pdu.") - })?; - if let Some(redact_id) = &content.redacts { - self.redact_pdu(redact_id, pdu)?; - } - } - _ => { - warn!("Unexpected or unsupported room version {}", room_version_id); - return Err(Error::BadRequest( - ErrorKind::BadJson, - "Unexpected or unsupported room version found", - )); - } - }; - } - TimelineEventType::SpaceChild => { - if let Some(_state_key) = &pdu.state_key { - services() - .rooms - .spaces - .roomid_spacechunk_cache - .lock() - .unwrap() - .remove(&pdu.room_id); - } - } 
- TimelineEventType::RoomMember => { - if let Some(state_key) = &pdu.state_key { - // if the state_key fails - let target_user_id = UserId::parse(state_key.clone()) - .expect("This state_key was previously validated"); - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|e| { - error!("Invalid room member event content in pdu: {e}"); - Error::bad_database("Invalid room member event content in pdu.") - })?; - - let invite_state = match content.membership { - MembershipState::Invite => { - let state = services().rooms.state.calculate_invite_state(pdu)?; - Some(state) - } - _ => None, - }; - - // Update our membership info, we do this here incase a user is invited - // and immediately leaves we need the DB to record the invite event for auth - services() - .rooms - .state_cache - .update_membership( - &pdu.room_id, - &target_user_id, - content, - &pdu.sender, - invite_state, - true, - ) - .await?; - } - } - TimelineEventType::RoomMessage => { - #[derive(Deserialize)] - struct ExtractBody { - body: Option, - } - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if let Some(body) = content.body { - services() - .rooms - .search - .index_pdu(shortroomid, &pdu_id, &body)?; - - let admin_room = services().rooms.alias.resolve_local_alias( - <&RoomAliasId>::try_from( - format!("#admins:{}", services().globals.server_name()).as_str(), - ) - .expect("#admins:server_name is a valid room alias"), - )?; - let server_user = format!("@conduit:{}", services().globals.server_name()); - - let to_conduit = body.starts_with(&format!("{server_user}: ")) - || body.starts_with(&format!("{server_user} ")) - || body.starts_with("!admin") - || body == format!("{server_user}:") - || body == server_user; - - // This will evaluate to false if the emergency password is set up so that - // the administrator can execute commands as conduit - let from_conduit = pdu.sender == server_user - && 
services().globals.emergency_password().is_none(); - - if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { - services().admin.process_message(body, pdu.event_id.clone()); - } - } - } - _ => {} - } - - // Update Relationships - #[derive(Deserialize)] - struct ExtractRelatesTo { - #[serde(rename = "m.relates_to")] - relates_to: Relation, - } - - #[derive(Clone, Debug, Deserialize)] - struct ExtractEventId { - event_id: OwnedEventId, - } - #[derive(Clone, Debug, Deserialize)] - struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractEventId, - } - - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - if let Some(related_pducount) = services() - .rooms - .timeline - .get_pdu_count(&content.relates_to.event_id)? - { - services() - .rooms - .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; - } - } - - if let Ok(content) = serde_json::from_str::(pdu.content.get()) { - match content.relates_to { - Relation::Reply { in_reply_to } => { - // We need to do it again here, because replies don't have - // event_id as a top level field - if let Some(related_pducount) = services() - .rooms - .timeline - .get_pdu_count(&in_reply_to.event_id)? - { - services() - .rooms - .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount)?; - } - } - Relation::Thread(thread) => { - services() - .rooms - .threads - .add_to_thread(&thread.event_id, pdu)?; - } - _ => {} // TODO: Aggregate other types - } - } - - for appservice in services().appservice.all()? { - if services() - .rooms - .state_cache - .appservice_in_room(&pdu.room_id, &appservice)? - { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - continue; - } - - // If the RoomMember event has a non-empty state_key, it is targeted at someone. - // If it is our appservice user, we send this PDU to it. 
- if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - let appservice_uid = appservice.1.sender_localpart.as_str(); - if state_key_uid == appservice_uid { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - continue; - } - } - } - - let namespaces = appservice.1.namespaces; - - // TODO: create some helper function to change from Strings to Regexes - let users = namespaces - .users - .iter() - .filter_map(|user| Regex::new(user.regex.as_str()).ok()) - .collect::>(); - let aliases = namespaces - .aliases - .iter() - .filter_map(|alias| Regex::new(alias.regex.as_str()).ok()) - .collect::>(); - let rooms = namespaces - .rooms - .iter() - .filter_map(|room| Regex::new(room.regex.as_str()).ok()) - .collect::>(); - - let matching_users = |users: &Regex| { - users.is_match(pdu.sender.as_str()) - || pdu.kind == TimelineEventType::RoomMember - && pdu - .state_key - .as_ref() - .map_or(false, |state_key| users.is_match(state_key)) - }; - let matching_aliases = |aliases: &Regex| { - services() - .rooms - .alias - .local_aliases_for_room(&pdu.room_id) - .filter_map(std::result::Result::ok) - .any(|room_alias| aliases.is_match(room_alias.as_str())) - }; - - if aliases.iter().any(matching_aliases) - || rooms - .iter() - .any(|namespace| namespace.is_match(pdu.room_id.as_str())) - || users.iter().any(matching_users) - { - services() - .sending - .send_pdu_appservice(appservice.0, pdu_id.clone())?; - } - } - - Ok(pdu_id) - } - - pub fn create_hash_and_sign_event( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result<(PduEvent, CanonicalJsonObject)> { - let PduBuilder { - event_type, - content, - unsigned, - state_key, - redacts, - } = pdu_builder; - - let prev_events: Vec<_> = services() - .rooms - 
.state - .get_forward_extremities(room_id)? - .into_iter() - .take(20) - .collect(); - - // If there was no create event yet, assume we are creating a room - let room_version_id = services() - .rooms - .state - .get_room_version(room_id) - .or_else(|_| { - if event_type == TimelineEventType::RoomCreate { - #[derive(Deserialize)] - struct RoomCreate { - room_version: RoomVersionId, - } - let content = serde_json::from_str::(content.get()) - .expect("Invalid content in RoomCreate pdu."); - Ok(content.room_version) - } else { - Err(Error::InconsistentRoomState( - "non-create event for room of unknown version", - room_id.to_owned(), - )) - } - })?; - - let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); - - let auth_events = services().rooms.state.get_auth_events( - room_id, - &event_type, - sender, - state_key.as_deref(), - &content, - )?; - - // Our depth is the maximum depth of prev_events + 1 - let depth = prev_events - .iter() - .filter_map(|event_id| Some(services().rooms.timeline.get_pdu(event_id).ok()??.depth)) - .max() - .unwrap_or_else(|| uint!(0)) - + uint!(1); - - let mut unsigned = unsigned.unwrap_or_default(); - - if let Some(state_key) = &state_key { - if let Some(prev_pdu) = services().rooms.state_accessor.room_state_get( - room_id, - &event_type.to_string().into(), - state_key, - )? 
{ - unsigned.insert( - "prev_content".to_owned(), - serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), - ); - unsigned.insert( - "prev_sender".to_owned(), - serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), - ); - } - } - - let mut pdu = PduEvent { - event_id: ruma::event_id!("$thiswillbefilledinlater").into(), - room_id: room_id.to_owned(), - sender: sender.to_owned(), - origin_server_ts: utils::millis_since_unix_epoch() - .try_into() - .expect("time is valid"), - kind: event_type, - content, - state_key, - prev_events, - depth, - auth_events: auth_events - .values() - .map(|pdu| pdu.event_id.clone()) - .collect(), - redacts, - unsigned: if unsigned.is_empty() { - None - } else { - Some(to_raw_value(&unsigned).expect("to_raw_value always works")) - }, - hashes: EventHash { - sha256: "aaa".to_owned(), - }, - signatures: None, - }; - - let auth_check = state_res::auth_check( - &room_version, - &pdu, - None::, // TODO: third_party_invite - |k, s| auth_events.get(&(k.clone(), s.to_owned())), - ) - .map_err(|e| { - error!("Auth check failed: {:?}", e); - Error::bad_database("Auth check failed.") - })?; - - if !auth_check { - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Event is not authorized.", - )); - } - - // Hash and sign - let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?; - - pdu_json.remove("event_id"); - - // Add origin because synapse likes that (and it's required in the spec) - pdu_json.insert( - "origin".to_owned(), - to_canonical_value(services().globals.server_name()) - .expect("server name is a valid CanonicalJsonValue"), - ); - - match ruma::signatures::hash_and_sign_event( - services().globals.server_name().as_str(), - services().globals.keypair(), - &mut pdu_json, - &room_version_id, - ) { - Ok(_) => {} - Err(e) => { - return match e { - 
ruma::signatures::Error::PduSize => Err(Error::BadRequest( - ErrorKind::TooLarge, - "Message is too long", - )), - _ => Err(Error::BadRequest( - ErrorKind::Unknown, - "Signing event failed", - )), - } - } - } - - // Generate event id - pdu.event_id = EventId::parse_arc(format!( - "${}", - ruma::signatures::reference_hash(&pdu_json, &room_version_id) - .expect("ruma can calculate reference hashes") - )) - .expect("ruma's reference hashes are valid event ids"); - - pdu_json.insert( - "event_id".to_owned(), - CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), - ); - - // Generate short event id - let _shorteventid = services() - .rooms - .short - .get_or_create_shorteventid(&pdu.event_id)?; - - Ok((pdu, pdu_json)) - } - - /// Creates a new persisted data unit and adds it to a room. This function takes a - /// roomid_mutex_state, meaning that only this function is able to mutate the room state. - #[tracing::instrument(skip(self, state_lock))] - pub async fn build_and_append_pdu( - &self, - pdu_builder: PduBuilder, - sender: &UserId, - room_id: &RoomId, - state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { - let (pdu, pdu_json) = - self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; - - let admin_room = services().rooms.alias.resolve_local_alias( - <&RoomAliasId>::try_from( - format!("#admins:{}", services().globals.server_name()).as_str(), - ) - .expect("#admins:server_name is a valid room alias"), - )?; - if admin_room.filter(|v| v == room_id).is_some() { - match pdu.event_type() { - TimelineEventType::RoomEncryption => { - warn!("Encryption is not allowed in the admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Encryption is not allowed in the admins room.", - )); - } - TimelineEventType::RoomMember => { - #[derive(Deserialize)] - struct ExtractMembership { - membership: MembershipState, - } - - let target = pdu - .state_key() - .filter(|v| 
v.starts_with('@')) - .unwrap_or(sender.as_str()); - let server_name = services().globals.server_name(); - let server_user = format!("@conduit:{server_name}"); - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if content.membership == MembershipState::Leave { - if target == server_user { - warn!("Conduit user cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Conduit user cannot leave from admins room.", - )); - } - - let count = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) - .count(); - if count < 2 { - warn!("Last admin cannot leave from admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Last admin cannot leave from admins room.", - )); - } - } - - if content.membership == MembershipState::Ban && pdu.state_key().is_some() { - if target == server_user { - warn!("Conduit user cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Conduit user cannot be banned in admins room.", - )); - } - - let count = services() - .rooms - .state_cache - .room_members(room_id) - .filter_map(std::result::Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) - .count(); - if count < 2 { - warn!("Last admin cannot be banned in admins room"); - return Err(Error::BadRequest( - ErrorKind::Forbidden, - "Last admin cannot be banned in admins room.", - )); - } - } - } - _ => {} - } - } - - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. 
- let statehashid = services().rooms.state.append_to_state(&pdu)?; - - let pdu_id = self - .append_pdu( - &pdu, - pdu_json, - // Since this PDU references all pdu_leaves we can update the leaves - // of the room - vec![(*pdu.event_id).to_owned()], - state_lock, - ) - .await?; - - // We set the room state after inserting the pdu, so that we never have a moment in time - // where events in the current room state do not exist - services() - .rooms - .state - .set_room_state(room_id, statehashid, state_lock)?; - - let mut servers: HashSet = services() - .rooms - .state_cache - .room_servers(room_id) - .filter_map(std::result::Result::ok) - .collect(); - - // In case we are kicking or banning a user, we need to inform their server of the change - if pdu.kind == TimelineEventType::RoomMember { - if let Some(state_key_uid) = &pdu - .state_key - .as_ref() - .and_then(|state_key| UserId::parse(state_key.as_str()).ok()) - { - servers.insert(state_key_uid.server_name().to_owned()); - } - } - - // Remove our server from the server list since it will be added to it by room_servers() and/or the if statement above - servers.remove(services().globals.server_name()); - - services().sending.send_pdu(servers.into_iter(), &pdu_id)?; - - Ok(pdu.event_id) - } - - /// Append the incoming event setting the state snapshot to the state from the - /// server that sent the event. - #[tracing::instrument(skip_all)] - pub async fn append_incoming_pdu( - &self, - pdu: &PduEvent, - pdu_json: CanonicalJsonObject, - new_room_leaves: Vec, - state_ids_compressed: Arc>, - soft_fail: bool, - state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex - ) -> Result>> { - // We append to state before appending the pdu, so we don't have a moment in time with the - // pdu without it's state. This is okay because append_pdu can't fail. 
- services().rooms.state.set_event_state( - &pdu.event_id, - &pdu.room_id, - state_ids_compressed, - )?; - - if soft_fail { - services() - .rooms - .pdu_metadata - .mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; - services().rooms.state.set_forward_extremities( - &pdu.room_id, - new_room_leaves, - state_lock, - )?; - return Ok(None); - } - - let pdu_id = services() - .rooms - .timeline - .append_pdu(pdu, pdu_json, new_room_leaves, state_lock) - .await?; - - Ok(Some(pdu_id)) - } - - /// Returns an iterator over all PDUs in a room. - pub fn all_pdus<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - ) -> Result> + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()) - } - - /// Returns an iterator over all events and their tokens in a room that happened before the - /// event with id `until` in reverse-chronological order. - #[tracing::instrument(skip(self))] - pub fn pdus_until<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - until: PduCount, - ) -> Result> + 'a> { - self.db.pdus_until(user_id, room_id, until) - } - - /// Returns an iterator over all events and their token in a room that happened after the event - /// with id `from` in chronological order. - #[tracing::instrument(skip(self))] - pub fn pdus_after<'a>( - &'a self, - user_id: &UserId, - room_id: &RoomId, - from: PduCount, - ) -> Result> + 'a> { - self.db.pdus_after(user_id, room_id, from) - } - - /// Replace a PDU with the redacted form. - #[tracing::instrument(skip(self, reason))] - pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { - // TODO: Don't reserialize, keep original json - if let Some(pdu_id) = self.get_pdu_id(event_id)? { - let mut pdu = self - .get_pdu_from_id(&pdu_id)? 
- .ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; - let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; - pdu.redact(room_version_id, reason)?; - self.replace_pdu( - &pdu_id, - &utils::to_canonical_object(&pdu).map_err(|e| { - error!("Failed to convert PDU to canonical JSON: {}", e); - Error::bad_database("Failed to convert PDU to canonical JSON.") - })?, - &pdu, - )?; - } - // If event does not exist, just noop - Ok(()) - } - - #[tracing::instrument(skip(self, room_id))] - pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { - let first_pdu = self - .all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? - .next() - .expect("Room is not empty")?; - - if first_pdu.0 < from { - // No backfill required, there are still events between them - return Ok(()); - } - - let power_levels: RoomPowerLevelsEventContent = services() - .rooms - .state_accessor - .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? - .map(|ev| { - serde_json::from_str(ev.content.get()) - .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) - }) - .transpose()? 
- .unwrap_or_default(); - let mut admin_servers = power_levels - .users - .iter() - .filter(|(_, level)| **level > power_levels.users_default) - .map(|(user_id, _)| user_id.server_name()) - .collect::>(); - admin_servers.remove(services().globals.server_name()); - - // Request backfill - for backfill_server in admin_servers { - info!("Asking {backfill_server} for backfill"); - let response = services() - .sending - .send_federation_request( - backfill_server, - federation::backfill::get_backfill::v1::Request { - room_id: room_id.to_owned(), - v: vec![first_pdu.1.event_id.as_ref().to_owned()], - limit: uint!(100), - }, - ) - .await; - match response { - Ok(response) => { - let pub_key_map = RwLock::new(BTreeMap::new()); - for pdu in response.pdus { - if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await - { - warn!("Failed to add backfilled pdu: {e}"); - } - } - return Ok(()); - } - Err(e) => { - warn!("{backfill_server} could not provide backfill: {e}"); - } - } - } - - info!("No servers could backfill"); - Ok(()) - } - - #[tracing::instrument(skip(self, pdu))] - pub async fn backfill_pdu( - &self, - origin: &ServerName, - pdu: Box, - pub_key_map: &RwLock>>, - ) -> Result<()> { - let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; - - // Lock so we cannot backfill the same pdu twice at the same time - let mutex = Arc::clone( - services() - .globals - .roomid_mutex_federation - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let mutex_lock = mutex.lock().await; - - // Skip the PDU if we already have it as a timeline event - if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(&event_id)? 
{ - info!("We already know {event_id} at {pdu_id:?}"); - return Ok(()); - } - - services() - .rooms - .event_handler - .fetch_required_signing_keys([&value], pub_key_map) - .await?; - - services() - .rooms - .event_handler - .handle_incoming_pdu(origin, &event_id, &room_id, value, false, pub_key_map) - .await?; - - let value = self.get_pdu_json(&event_id)?.expect("We just created it"); - let pdu = self.get_pdu(&event_id)?.expect("We just created it"); - - let shortroomid = services() - .rooms - .short - .get_shortroomid(&room_id)? - .expect("room exists"); - - let mutex_insert = Arc::clone( - services() - .globals - .roomid_mutex_insert - .write() - .unwrap() - .entry(room_id.clone()) - .or_default(), - ); - let insert_lock = mutex_insert.lock().await; - - let count = services().globals.next_count()?; - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); - - // Insert pdu - self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; - - drop(insert_lock); - - if pdu.kind == TimelineEventType::RoomMessage { - #[derive(Deserialize)] - struct ExtractBody { - body: Option, - } - - let content = serde_json::from_str::(pdu.content.get()) - .map_err(|_| Error::bad_database("Invalid content in pdu."))?; - - if let Some(body) = content.body { - services() - .rooms - .search - .index_pdu(shortroomid, &pdu_id, &body)?; - } - } - drop(mutex_lock); - - info!("Prepended backfill pdu"); - Ok(()) - } + #[tracing::instrument(skip(self))] + pub fn first_pdu_in_room(&self, room_id: &RoomId) -> Result>> { + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)? + .next() + .map(|o| o.map(|(_, p)| Arc::new(p))) + .transpose() + } + + #[tracing::instrument(skip(self))] + pub fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result { + self.db.last_timeline_count(sender_user, room_id) + } + + /// Returns the `count` of this pdu's id. 
+ pub fn get_pdu_count(&self, event_id: &EventId) -> Result> { self.db.get_pdu_count(event_id) } + + // TODO Is this the same as the function above? + /* + #[tracing::instrument(skip(self))] + pub fn latest_pdu_count(&self, room_id: &RoomId) -> Result { + let prefix = self + .get_shortroomid(room_id)? + .expect("room exists") + .to_be_bytes() + .to_vec(); + + let mut last_possible_key = prefix.clone(); + last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes()); + + self.pduid_pdu + .iter_from(&last_possible_key, true) + .take_while(move |(k, _)| k.starts_with(&prefix)) + .next() + .map(|b| self.pdu_count(&b.0)) + .transpose() + .map(|op| op.unwrap_or_default()) + } + */ + + /// Returns the version of a room, if known + pub fn get_room_version(&self, room_id: &RoomId) -> Result> { + let create_event = services().rooms.state_accessor.room_state_get(room_id, &StateEventType::RoomCreate, "")?; + + let create_event_content: Option = create_event + .as_ref() + .map(|create_event| { + serde_json::from_str(create_event.content.get()).map_err(|e| { + warn!("Invalid create event: {}", e); + Error::bad_database("Invalid create event in db.") + }) + }) + .transpose()?; + + Ok(create_event_content.map(|content| content.room_version)) + } + + /// Returns the json of a pdu. + pub fn get_pdu_json(&self, event_id: &EventId) -> Result> { + self.db.get_pdu_json(event_id) + } + + /// Returns the json of a pdu. + pub fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result> { + self.db.get_non_outlier_pdu_json(event_id) + } + + /// Returns the pdu's id. + pub fn get_pdu_id(&self, event_id: &EventId) -> Result>> { self.db.get_pdu_id(event_id) } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. + pub fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result> { + self.db.get_non_outlier_pdu(event_id) + } + + /// Returns the pdu. + /// + /// Checks the `eventid_outlierpdu` Tree if not found in the timeline. 
+ pub fn get_pdu(&self, event_id: &EventId) -> Result>> { self.db.get_pdu(event_id) } + + /// Returns the pdu. + /// + /// This does __NOT__ check the outliers `Tree`. + pub fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result> { self.db.get_pdu_from_id(pdu_id) } + + /// Returns the pdu as a `BTreeMap`. + pub fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result> { + self.db.get_pdu_json_from_id(pdu_id) + } + + /// Removes a pdu and creates a new one with the same id. + #[tracing::instrument(skip(self))] + pub fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + self.db.replace_pdu(pdu_id, pdu_json, pdu) + } + + /// Creates a new persisted data unit and adds it to a room. + /// + /// By this point the incoming event should be fully authenticated, no auth + /// happens in `append_pdu`. + /// + /// Returns pdu id + #[tracing::instrument(skip(self, pdu, pdu_json, leaves))] + pub async fn append_pdu( + &self, + pdu: &PduEvent, + mut pdu_json: CanonicalJsonObject, + leaves: Vec, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result> { + let shortroomid = services().rooms.short.get_shortroomid(&pdu.room_id)?.expect("room exists"); + + // Make unsigned fields correct. 
This is not properly documented in the spec, + // but state events need to have previous content in the unsigned field, so + // clients can easily interpret things like membership changes + if let Some(state_key) = &pdu.state_key { + if let CanonicalJsonValue::Object(unsigned) = + pdu_json.entry("unsigned".to_owned()).or_insert_with(|| CanonicalJsonValue::Object(BTreeMap::default())) + { + if let Some(shortstatehash) = services().rooms.state_accessor.pdu_shortstatehash(&pdu.event_id).unwrap() + { + if let Some(prev_state) = services() + .rooms + .state_accessor + .state_get(shortstatehash, &pdu.kind.to_string().into(), state_key) + .unwrap() + { + unsigned.insert( + "prev_content".to_owned(), + CanonicalJsonValue::Object( + utils::to_canonical_object(prev_state.content.clone()).map_err(|e| { + error!("Failed to convert prev_state to canonical JSON: {}", e); + Error::bad_database("Failed to convert prev_state to canonical JSON.") + })?, + ), + ); + } + } + } else { + error!("Invalid unsigned type in pdu."); + } + } + + // We must keep track of all events that have been referenced. 
+ services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + services().rooms.state.set_forward_extremities(&pdu.room_id, leaves, state_lock)?; + + let mutex_insert = + Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(pdu.room_id.clone()).or_default()); + let insert_lock = mutex_insert.lock().await; + + let count1 = services().globals.next_count()?; + // Mark as read first so the sending client doesn't get a notification even if + // appending fails + services().rooms.edus.read_receipt.private_read_set(&pdu.room_id, &pdu.sender, count1)?; + services().rooms.user.reset_notification_counts(&pdu.sender, &pdu.room_id)?; + + let count2 = services().globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&count2.to_be_bytes()); + + // https://spec.matrix.org/v1.9/rooms/v11/#moving-the-redacts-property-of-mroomredaction-events-to-a-content-property + // For backwards-compatibility with older clients, + // servers should add a redacts property to the top level of m.room.redaction + // events in when serving such events over the Client-Server API. + if pdu.kind == TimelineEventType::RoomRedaction + && services().rooms.state.get_room_version(&pdu.room_id)? == RoomVersionId::V11 + { + #[derive(Deserialize)] + struct Redaction { + redacts: Option, + } + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in redaction pdu."))?; + + if let Some(redact_id) = &content.redacts { + pdu_json.insert("redacts".to_owned(), CanonicalJsonValue::String(redact_id.to_string())); + } + } + + // Insert pdu + self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2)?; + + drop(insert_lock); + + // See if the event matches any known pushers + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(&pdu.room_id, &StateEventType::RoomPowerLevels, "")? 
+ .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? + .unwrap_or_default(); + + let sync_pdu = pdu.to_sync_room_event(); + + let mut notifies = Vec::new(); + let mut highlights = Vec::new(); + + let mut push_target = services().rooms.state_cache.get_our_real_users(&pdu.room_id)?; + + if pdu.kind == TimelineEventType::RoomMember { + if let Some(state_key) = &pdu.state_key { + let target_user_id = UserId::parse(state_key.clone()).expect("This state_key was previously validated"); + + if !push_target.contains(&target_user_id) { + let mut target = push_target.as_ref().clone(); + target.insert(target_user_id); + push_target = Arc::new(target); + } + } + } + + for user in push_target.iter() { + // Don't notify the user of their own events + if user == &pdu.sender { + continue; + } + + let rules_for_user = services() + .account_data + .get(None, user, GlobalAccountDataEventType::PushRules.to_string().into())? + .map(|event| { + serde_json::from_str::(event.get()).map_err(|e| { + warn!("Invalid push rules event in db for user ID {user}: {e}"); + Error::bad_database("Invalid push rules event in db.") + }) + }) + .transpose()? + .map_or_else(|| Ruleset::server_default(user), |ev: PushRulesEvent| ev.content.global); + + let mut highlight = false; + let mut notify = false; + + for action in + services().pusher.get_actions(user, &rules_for_user, &power_levels, &sync_pdu, &pdu.room_id)? 
+ { + match action { + Action::Notify => notify = true, + Action::SetTweak(Tweak::Highlight(true)) => { + highlight = true; + }, + _ => {}, + }; + } + + if notify { + notifies.push(user.clone()); + } + + if highlight { + highlights.push(user.clone()); + } + + for push_key in services().pusher.get_pushkeys(user) { + services().sending.send_push_pdu(&pdu_id, user, push_key?)?; + } + } + + self.db.increment_notification_counts(&pdu.room_id, notifies, highlights)?; + + match pdu.kind { + TimelineEventType::RoomRedaction => { + let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + match room_version_id { + RoomVersionId::V1 + | RoomVersionId::V2 + | RoomVersionId::V3 + | RoomVersionId::V4 + | RoomVersionId::V5 + | RoomVersionId::V6 + | RoomVersionId::V7 + | RoomVersionId::V8 + | RoomVersionId::V9 + | RoomVersionId::V10 => { + if let Some(redact_id) = &pdu.redacts { + self.redact_pdu(redact_id, pdu)?; + } + }, + RoomVersionId::V11 => { + #[derive(Deserialize)] + struct Redaction { + redacts: Option, + } + let content = serde_json::from_str::(pdu.content.get()).map_err(|e| { + warn!("Invalid content in redaction pdu: {e}"); + Error::bad_database("Invalid content in redaction pdu.") + })?; + if let Some(redact_id) = &content.redacts { + self.redact_pdu(redact_id, pdu)?; + } + }, + _ => { + warn!("Unexpected or unsupported room version {}", room_version_id); + return Err(Error::BadRequest( + ErrorKind::BadJson, + "Unexpected or unsupported room version found", + )); + }, + }; + }, + TimelineEventType::SpaceChild => { + if let Some(_state_key) = &pdu.state_key { + services().rooms.spaces.roomid_spacechunk_cache.lock().unwrap().remove(&pdu.room_id); + } + }, + TimelineEventType::RoomMember => { + if let Some(state_key) = &pdu.state_key { + // if the state_key fails + let target_user_id = + UserId::parse(state_key.clone()).expect("This state_key was previously validated"); + + let content = serde_json::from_str::(pdu.content.get()).map_err(|e| { + 
error!("Invalid room member event content in pdu: {e}"); + Error::bad_database("Invalid room member event content in pdu.") + })?; + + let invite_state = match content.membership { + MembershipState::Invite => { + let state = services().rooms.state.calculate_invite_state(pdu)?; + Some(state) + }, + _ => None, + }; + + // Update our membership info, we do this here incase a user is invited + // and immediately leaves we need the DB to record the invite event for auth + services() + .rooms + .state_cache + .update_membership(&pdu.room_id, &target_user_id, content, &pdu.sender, invite_state, true) + .await?; + } + }, + TimelineEventType::RoomMessage => { + #[derive(Deserialize)] + struct ExtractBody { + body: Option, + } + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if let Some(body) = content.body { + services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; + + let admin_room = services().rooms.alias.resolve_local_alias( + <&RoomAliasId>::try_from(format!("#admins:{}", services().globals.server_name()).as_str()) + .expect("#admins:server_name is a valid room alias"), + )?; + let server_user = format!("@conduit:{}", services().globals.server_name()); + + let to_conduit = body.starts_with(&format!("{server_user}: ")) + || body.starts_with(&format!("{server_user} ")) + || body.starts_with("!admin") + || body == format!("{server_user}:") + || body == server_user; + + // This will evaluate to false if the emergency password is set up so that + // the administrator can execute commands as conduit + let from_conduit = pdu.sender == server_user && services().globals.emergency_password().is_none(); + + if to_conduit && !from_conduit && admin_room.as_ref() == Some(&pdu.room_id) { + services().admin.process_message(body, pdu.event_id.clone()); + } + } + }, + _ => {}, + } + + // Update Relationships + #[derive(Deserialize)] + struct ExtractRelatesTo { + #[serde(rename = "m.relates_to")] + 
relates_to: Relation, + } + + #[derive(Clone, Debug, Deserialize)] + struct ExtractEventId { + event_id: OwnedEventId, + } + #[derive(Clone, Debug, Deserialize)] + struct ExtractRelatesToEventId { + #[serde(rename = "m.relates_to")] + relates_to: ExtractEventId, + } + + if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&content.relates_to.event_id)? { + services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; + } + } + + if let Ok(content) = serde_json::from_str::(pdu.content.get()) { + match content.relates_to { + Relation::Reply { + in_reply_to, + } => { + // We need to do it again here, because replies don't have + // event_id as a top level field + if let Some(related_pducount) = services().rooms.timeline.get_pdu_count(&in_reply_to.event_id)? { + services().rooms.pdu_metadata.add_relation(PduCount::Normal(count2), related_pducount)?; + } + }, + Relation::Thread(thread) => { + services().rooms.threads.add_to_thread(&thread.event_id, pdu)?; + }, + _ => {}, // TODO: Aggregate other types + } + } + + for appservice in services().appservice.all()? { + if services().rooms.state_cache.appservice_in_room(&pdu.room_id, &appservice)? { + services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + continue; + } + + // If the RoomMember event has a non-empty state_key, it is targeted at someone. + // If it is our appservice user, we send this PDU to it. 
+ if pdu.kind == TimelineEventType::RoomMember { + if let Some(state_key_uid) = + &pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + let appservice_uid = appservice.1.sender_localpart.as_str(); + if state_key_uid == appservice_uid { + services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + continue; + } + } + } + + let namespaces = appservice.1.namespaces; + + // TODO: create some helper function to change from Strings to Regexes + let users = + namespaces.users.iter().filter_map(|user| Regex::new(user.regex.as_str()).ok()).collect::>(); + let aliases = + namespaces.aliases.iter().filter_map(|alias| Regex::new(alias.regex.as_str()).ok()).collect::>(); + let rooms = + namespaces.rooms.iter().filter_map(|room| Regex::new(room.regex.as_str()).ok()).collect::>(); + + let matching_users = |users: &Regex| { + users.is_match(pdu.sender.as_str()) + || pdu.kind == TimelineEventType::RoomMember + && pdu.state_key.as_ref().map_or(false, |state_key| users.is_match(state_key)) + }; + let matching_aliases = |aliases: &Regex| { + services() + .rooms + .alias + .local_aliases_for_room(&pdu.room_id) + .filter_map(std::result::Result::ok) + .any(|room_alias| aliases.is_match(room_alias.as_str())) + }; + + if aliases.iter().any(matching_aliases) + || rooms.iter().any(|namespace| namespace.is_match(pdu.room_id.as_str())) + || users.iter().any(matching_users) + { + services().sending.send_pdu_appservice(appservice.0, pdu_id.clone())?; + } + } + + Ok(pdu_id) + } + + pub fn create_hash_and_sign_event( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + _mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result<(PduEvent, CanonicalJsonObject)> { + let PduBuilder { + event_type, + content, + unsigned, + state_key, + redacts, + } = pdu_builder; + + let prev_events: Vec<_> = + 
services().rooms.state.get_forward_extremities(room_id)?.into_iter().take(20).collect(); + + // If there was no create event yet, assume we are creating a room + let room_version_id = services().rooms.state.get_room_version(room_id).or_else(|_| { + if event_type == TimelineEventType::RoomCreate { + #[derive(Deserialize)] + struct RoomCreate { + room_version: RoomVersionId, + } + let content = + serde_json::from_str::(content.get()).expect("Invalid content in RoomCreate pdu."); + Ok(content.room_version) + } else { + Err(Error::InconsistentRoomState( + "non-create event for room of unknown version", + room_id.to_owned(), + )) + } + })?; + + let room_version = RoomVersion::new(&room_version_id).expect("room version is supported"); + + let auth_events = + services().rooms.state.get_auth_events(room_id, &event_type, sender, state_key.as_deref(), &content)?; + + // Our depth is the maximum depth of prev_events + 1 + let depth = prev_events + .iter() + .filter_map(|event_id| Some(services().rooms.timeline.get_pdu(event_id).ok()??.depth)) + .max() + .unwrap_or_else(|| uint!(0)) + + uint!(1); + + let mut unsigned = unsigned.unwrap_or_default(); + + if let Some(state_key) = &state_key { + if let Some(prev_pdu) = + services().rooms.state_accessor.room_state_get(room_id, &event_type.to_string().into(), state_key)? 
+ { + unsigned.insert( + "prev_content".to_owned(), + serde_json::from_str(prev_pdu.content.get()).expect("string is valid json"), + ); + unsigned.insert( + "prev_sender".to_owned(), + serde_json::to_value(&prev_pdu.sender).expect("UserId::to_value always works"), + ); + } + } + + let mut pdu = PduEvent { + event_id: ruma::event_id!("$thiswillbefilledinlater").into(), + room_id: room_id.to_owned(), + sender: sender.to_owned(), + origin_server_ts: utils::millis_since_unix_epoch().try_into().expect("time is valid"), + kind: event_type, + content, + state_key, + prev_events, + depth, + auth_events: auth_events.values().map(|pdu| pdu.event_id.clone()).collect(), + redacts, + unsigned: if unsigned.is_empty() { + None + } else { + Some(to_raw_value(&unsigned).expect("to_raw_value always works")) + }, + hashes: EventHash { + sha256: "aaa".to_owned(), + }, + signatures: None, + }; + + let auth_check = state_res::auth_check( + &room_version, + &pdu, + None::, // TODO: third_party_invite + |k, s| auth_events.get(&(k.clone(), s.to_owned())), + ) + .map_err(|e| { + error!("Auth check failed: {:?}", e); + Error::bad_database("Auth check failed.") + })?; + + if !auth_check { + return Err(Error::BadRequest(ErrorKind::Forbidden, "Event is not authorized.")); + } + + // Hash and sign + let mut pdu_json = utils::to_canonical_object(&pdu).map_err(|e| { + error!("Failed to convert PDU to canonical JSON: {}", e); + Error::bad_database("Failed to convert PDU to canonical JSON.") + })?; + + pdu_json.remove("event_id"); + + // Add origin because synapse likes that (and it's required in the spec) + pdu_json.insert( + "origin".to_owned(), + to_canonical_value(services().globals.server_name()).expect("server name is a valid CanonicalJsonValue"), + ); + + match ruma::signatures::hash_and_sign_event( + services().globals.server_name().as_str(), + services().globals.keypair(), + &mut pdu_json, + &room_version_id, + ) { + Ok(_) => {}, + Err(e) => { + return match e { + 
ruma::signatures::Error::PduSize => { + Err(Error::BadRequest(ErrorKind::TooLarge, "Message is too long")) + }, + _ => Err(Error::BadRequest(ErrorKind::Unknown, "Signing event failed")), + } + }, + } + + // Generate event id + pdu.event_id = EventId::parse_arc(format!( + "${}", + ruma::signatures::reference_hash(&pdu_json, &room_version_id).expect("ruma can calculate reference hashes") + )) + .expect("ruma's reference hashes are valid event ids"); + + pdu_json.insert( + "event_id".to_owned(), + CanonicalJsonValue::String(pdu.event_id.as_str().to_owned()), + ); + + // Generate short event id + let _shorteventid = services().rooms.short.get_or_create_shorteventid(&pdu.event_id)?; + + Ok((pdu, pdu_json)) + } + + /// Creates a new persisted data unit and adds it to a room. This function + /// takes a roomid_mutex_state, meaning that only this function is able to + /// mutate the room state. + #[tracing::instrument(skip(self, state_lock))] + pub async fn build_and_append_pdu( + &self, + pdu_builder: PduBuilder, + sender: &UserId, + room_id: &RoomId, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result> { + let (pdu, pdu_json) = self.create_hash_and_sign_event(pdu_builder, sender, room_id, state_lock)?; + + let admin_room = services().rooms.alias.resolve_local_alias( + <&RoomAliasId>::try_from(format!("#admins:{}", services().globals.server_name()).as_str()) + .expect("#admins:server_name is a valid room alias"), + )?; + if admin_room.filter(|v| v == room_id).is_some() { + match pdu.event_type() { + TimelineEventType::RoomEncryption => { + warn!("Encryption is not allowed in the admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Encryption is not allowed in the admins room.", + )); + }, + TimelineEventType::RoomMember => { + #[derive(Deserialize)] + struct ExtractMembership { + membership: MembershipState, + } + + let target = pdu.state_key().filter(|v| 
v.starts_with('@')).unwrap_or(sender.as_str()); + let server_name = services().globals.server_name(); + let server_user = format!("@conduit:{server_name}"); + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if content.membership == MembershipState::Leave { + if target == server_user { + warn!("Conduit user cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Conduit user cannot leave from admins room.", + )); + } + + let count = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|m| m.server_name() == server_name) + .filter(|m| m != target) + .count(); + if count < 2 { + warn!("Last admin cannot leave from admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Last admin cannot leave from admins room.", + )); + } + } + + if content.membership == MembershipState::Ban && pdu.state_key().is_some() { + if target == server_user { + warn!("Conduit user cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Conduit user cannot be banned in admins room.", + )); + } + + let count = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(std::result::Result::ok) + .filter(|m| m.server_name() == server_name) + .filter(|m| m != target) + .count(); + if count < 2 { + warn!("Last admin cannot be banned in admins room"); + return Err(Error::BadRequest( + ErrorKind::Forbidden, + "Last admin cannot be banned in admins room.", + )); + } + } + }, + _ => {}, + } + } + + // We append to state before appending the pdu, so we don't have a moment in + // time with the pdu without it's state. This is okay because append_pdu can't + // fail. 
+ let statehashid = services().rooms.state.append_to_state(&pdu)?; + + let pdu_id = self + .append_pdu( + &pdu, + pdu_json, + // Since this PDU references all pdu_leaves we can update the leaves + // of the room + vec![(*pdu.event_id).to_owned()], + state_lock, + ) + .await?; + + // We set the room state after inserting the pdu, so that we never have a moment + // in time where events in the current room state do not exist + services().rooms.state.set_room_state(room_id, statehashid, state_lock)?; + + let mut servers: HashSet = + services().rooms.state_cache.room_servers(room_id).filter_map(std::result::Result::ok).collect(); + + // In case we are kicking or banning a user, we need to inform their server of + // the change + if pdu.kind == TimelineEventType::RoomMember { + if let Some(state_key_uid) = + &pdu.state_key.as_ref().and_then(|state_key| UserId::parse(state_key.as_str()).ok()) + { + servers.insert(state_key_uid.server_name().to_owned()); + } + } + + // Remove our server from the server list since it will be added to it by + // room_servers() and/or the if statement above + servers.remove(services().globals.server_name()); + + services().sending.send_pdu(servers.into_iter(), &pdu_id)?; + + Ok(pdu.event_id) + } + + /// Append the incoming event setting the state snapshot to the state from + /// the server that sent the event. + #[tracing::instrument(skip_all)] + pub async fn append_incoming_pdu( + &self, + pdu: &PduEvent, + pdu_json: CanonicalJsonObject, + new_room_leaves: Vec, + state_ids_compressed: Arc>, + soft_fail: bool, + state_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex + ) -> Result>> { + // We append to state before appending the pdu, so we don't have a moment in + // time with the pdu without it's state. This is okay because append_pdu can't + // fail. 
+ services().rooms.state.set_event_state(&pdu.event_id, &pdu.room_id, state_ids_compressed)?; + + if soft_fail { + services().rooms.pdu_metadata.mark_as_referenced(&pdu.room_id, &pdu.prev_events)?; + services().rooms.state.set_forward_extremities(&pdu.room_id, new_room_leaves, state_lock)?; + return Ok(None); + } + + let pdu_id = services().rooms.timeline.append_pdu(pdu, pdu_json, new_room_leaves, state_lock).await?; + + Ok(Some(pdu_id)) + } + + /// Returns an iterator over all PDUs in a room. + pub fn all_pdus<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, + ) -> Result> + 'a> { + self.pdus_after(user_id, room_id, PduCount::min()) + } + + /// Returns an iterator over all events and their tokens in a room that + /// happened before the event with id `until` in reverse-chronological + /// order. + #[tracing::instrument(skip(self))] + pub fn pdus_until<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, until: PduCount, + ) -> Result> + 'a> { + self.db.pdus_until(user_id, room_id, until) + } + + /// Returns an iterator over all events and their token in a room that + /// happened after the event with id `from` in chronological order. + #[tracing::instrument(skip(self))] + pub fn pdus_after<'a>( + &'a self, user_id: &UserId, room_id: &RoomId, from: PduCount, + ) -> Result> + 'a> { + self.db.pdus_after(user_id, room_id, from) + } + + /// Replace a PDU with the redacted form. + #[tracing::instrument(skip(self, reason))] + pub fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent) -> Result<()> { + // TODO: Don't reserialize, keep original json + if let Some(pdu_id) = self.get_pdu_id(event_id)? 
{ + let mut pdu = + self.get_pdu_from_id(&pdu_id)?.ok_or_else(|| Error::bad_database("PDU ID points to invalid PDU."))?; + let room_version_id = services().rooms.state.get_room_version(&pdu.room_id)?; + pdu.redact(room_version_id, reason)?; + self.replace_pdu( + &pdu_id, + &utils::to_canonical_object(&pdu).map_err(|e| { + error!("Failed to convert PDU to canonical JSON: {}", e); + Error::bad_database("Failed to convert PDU to canonical JSON.") + })?, + &pdu, + )?; + } + // If event does not exist, just noop + Ok(()) + } + + #[tracing::instrument(skip(self, room_id))] + pub async fn backfill_if_required(&self, room_id: &RoomId, from: PduCount) -> Result<()> { + let first_pdu = + self.all_pdus(user_id!("@doesntmatter:conduit.rs"), room_id)?.next().expect("Room is not empty")?; + + if first_pdu.0 < from { + // No backfill required, there are still events between them + return Ok(()); + } + + let power_levels: RoomPowerLevelsEventContent = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .map(|ev| { + serde_json::from_str(ev.content.get()) + .map_err(|_| Error::bad_database("invalid m.room.power_levels event")) + }) + .transpose()? 
+ .unwrap_or_default(); + let mut admin_servers = power_levels + .users + .iter() + .filter(|(_, level)| **level > power_levels.users_default) + .map(|(user_id, _)| user_id.server_name()) + .collect::>(); + admin_servers.remove(services().globals.server_name()); + + // Request backfill + for backfill_server in admin_servers { + info!("Asking {backfill_server} for backfill"); + let response = services() + .sending + .send_federation_request( + backfill_server, + federation::backfill::get_backfill::v1::Request { + room_id: room_id.to_owned(), + v: vec![first_pdu.1.event_id.as_ref().to_owned()], + limit: uint!(100), + }, + ) + .await; + match response { + Ok(response) => { + let pub_key_map = RwLock::new(BTreeMap::new()); + for pdu in response.pdus { + if let Err(e) = self.backfill_pdu(backfill_server, pdu, &pub_key_map).await { + warn!("Failed to add backfilled pdu: {e}"); + } + } + return Ok(()); + }, + Err(e) => { + warn!("{backfill_server} could not provide backfill: {e}"); + }, + } + } + + info!("No servers could backfill"); + Ok(()) + } + + #[tracing::instrument(skip(self, pdu))] + pub async fn backfill_pdu( + &self, origin: &ServerName, pdu: Box, + pub_key_map: &RwLock>>, + ) -> Result<()> { + let (event_id, value, room_id) = server_server::parse_incoming_pdu(&pdu)?; + + // Lock so we cannot backfill the same pdu twice at the same time + let mutex = + Arc::clone(services().globals.roomid_mutex_federation.write().unwrap().entry(room_id.clone()).or_default()); + let mutex_lock = mutex.lock().await; + + // Skip the PDU if we already have it as a timeline event + if let Some(pdu_id) = services().rooms.timeline.get_pdu_id(&event_id)? 
{ + info!("We already know {event_id} at {pdu_id:?}"); + return Ok(()); + } + + services().rooms.event_handler.fetch_required_signing_keys([&value], pub_key_map).await?; + + services() + .rooms + .event_handler + .handle_incoming_pdu(origin, &event_id, &room_id, value, false, pub_key_map) + .await?; + + let value = self.get_pdu_json(&event_id)?.expect("We just created it"); + let pdu = self.get_pdu(&event_id)?.expect("We just created it"); + + let shortroomid = services().rooms.short.get_shortroomid(&room_id)?.expect("room exists"); + + let mutex_insert = + Arc::clone(services().globals.roomid_mutex_insert.write().unwrap().entry(room_id.clone()).or_default()); + let insert_lock = mutex_insert.lock().await; + + let count = services().globals.next_count()?; + let mut pdu_id = shortroomid.to_be_bytes().to_vec(); + pdu_id.extend_from_slice(&0_u64.to_be_bytes()); + pdu_id.extend_from_slice(&(u64::MAX - count).to_be_bytes()); + + // Insert pdu + self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value)?; + + drop(insert_lock); + + if pdu.kind == TimelineEventType::RoomMessage { + #[derive(Deserialize)] + struct ExtractBody { + body: Option, + } + + let content = serde_json::from_str::(pdu.content.get()) + .map_err(|_| Error::bad_database("Invalid content in pdu."))?; + + if let Some(body) = content.body { + services().rooms.search.index_pdu(shortroomid, &pdu_id, &body)?; + } + } + drop(mutex_lock); + + info!("Prepended backfill pdu"); + Ok(()) + } } #[cfg(test)] mod tests { - use super::*; + use super::*; - #[test] - fn comparisons() { - assert!(PduCount::Normal(1) < PduCount::Normal(2)); - assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); - assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); - assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); - } + #[test] + fn comparisons() { + assert!(PduCount::Normal(1) < PduCount::Normal(2)); + assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); + assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); + 
assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); + } } diff --git a/src/service/rooms/user/data.rs b/src/service/rooms/user/data.rs index 4b8a4eca..2fd1c29e 100644 --- a/src/service/rooms/user/data.rs +++ b/src/service/rooms/user/data.rs @@ -1,27 +1,22 @@ -use crate::Result; use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; +use crate::Result; + pub trait Data: Send + Sync { - fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; + fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; - fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; + fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result; - // Returns the count at which the last reset_notification_counts was called - fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result; + // Returns the count at which the last reset_notification_counts was called + fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result; - fn associate_token_shortstatehash( - &self, - room_id: &RoomId, - token: u64, - shortstatehash: u64, - ) -> Result<()>; + fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()>; - fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; + fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result>; - fn get_shared_rooms<'a>( - &'a self, - users: Vec, - ) -> Result> + 'a>>; + fn get_shared_rooms<'a>( + &'a self, users: Vec, + ) -> Result> + 'a>>; } diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index 672e502d..7cf49d9e 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -6,44 +6,35 @@ use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId}; use 
crate::Result; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - self.db.reset_notification_counts(user_id, room_id) - } + pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { + self.db.reset_notification_counts(user_id, room_id) + } - pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.notification_count(user_id, room_id) - } + pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.notification_count(user_id, room_id) + } - pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.highlight_count(user_id, room_id) - } + pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.highlight_count(user_id, room_id) + } - pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { - self.db.last_notification_read(user_id, room_id) - } + pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result { + self.db.last_notification_read(user_id, room_id) + } - pub fn associate_token_shortstatehash( - &self, - room_id: &RoomId, - token: u64, - shortstatehash: u64, - ) -> Result<()> { - self.db - .associate_token_shortstatehash(room_id, token, shortstatehash) - } + pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> { + self.db.associate_token_shortstatehash(room_id, token, shortstatehash) + } - pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { - self.db.get_token_shortstatehash(room_id, token) - } + pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result> { + self.db.get_token_shortstatehash(room_id, token) + } - pub fn get_shared_rooms( - &self, - users: Vec, - ) -> Result>> { - 
self.db.get_shared_rooms(users) - } + pub fn get_shared_rooms(&self, users: Vec) -> Result>> { + self.db.get_shared_rooms(users) + } } diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index 427ee939..46f3cd71 100644 --- a/src/service/sending/data.rs +++ b/src/service/sending/data.rs @@ -1,28 +1,22 @@ use ruma::ServerName; +use super::{OutgoingKind, SendingEventType}; use crate::Result; -use super::{OutgoingKind, SendingEventType}; - -type OutgoingSendingIter<'a> = - Box, OutgoingKind, SendingEventType)>> + 'a>; +type OutgoingSendingIter<'a> = Box, OutgoingKind, SendingEventType)>> + 'a>; type SendingEventTypeIter<'a> = Box, SendingEventType)>> + 'a>; pub trait Data: Send + Sync { - fn active_requests(&self) -> OutgoingSendingIter<'_>; - fn active_requests_for(&self, outgoing_kind: &OutgoingKind) -> SendingEventTypeIter<'_>; - fn delete_active_request(&self, key: Vec) -> Result<()>; - fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; - fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; - fn queue_requests( - &self, - requests: &[(&OutgoingKind, SendingEventType)], - ) -> Result>>; - fn queued_requests<'a>( - &'a self, - outgoing_kind: &OutgoingKind, - ) -> Box)>> + 'a>; - fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()>; - fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; - fn get_latest_educount(&self, server_name: &ServerName) -> Result; + fn active_requests(&self) -> OutgoingSendingIter<'_>; + fn active_requests_for(&self, outgoing_kind: &OutgoingKind) -> SendingEventTypeIter<'_>; + fn delete_active_request(&self, key: Vec) -> Result<()>; + fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn delete_all_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()>; + fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result>>; + fn queued_requests<'a>( + 
&'a self, outgoing_kind: &OutgoingKind, + ) -> Box)>> + 'a>; + fn mark_as_active(&self, events: &[(SendingEventType, Vec)]) -> Result<()>; + fn set_latest_educount(&self, server_name: &ServerName, educount: u64) -> Result<()>; + fn get_latest_educount(&self, server_name: &ServerName) -> Result; } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 7a8ded14..86de6f1c 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -1,797 +1,709 @@ mod data; -pub use data::Data; -use ipaddress::IPAddress; - use std::{ - collections::{BTreeMap, HashMap, HashSet}, - fmt::Debug, - sync::Arc, - time::{Duration, Instant}, + collections::{BTreeMap, HashMap, HashSet}, + fmt::Debug, + sync::Arc, + time::{Duration, Instant}, }; -use crate::{ - api::{appservice_server, server_server}, - services, - utils::calculate_hash, - Config, Error, PduEvent, Result, -}; -use federation::transactions::send_transaction_message; -use futures_util::{stream::FuturesUnordered, StreamExt}; - use base64::{engine::general_purpose, Engine as _}; - +pub use data::Data; +use federation::transactions::send_transaction_message; +use futures_util::{stream::FuturesUnordered, StreamExt}; +use ipaddress::IPAddress; use ruma::{ - api::{ - appservice::{self, Registration}, - federation::{ - self, - transactions::edu::{ - DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, - ReceiptData, ReceiptMap, - }, - }, - OutgoingRequest, - }, - device_id, - events::{ - push_rules::PushRulesEvent, receipt::ReceiptType, AnySyncEphemeralRoomEvent, - GlobalAccountDataEventType, - }, - push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName, UInt, UserId, + api::{ + appservice::{self, Registration}, + federation::{ + self, + transactions::edu::{ + DeviceListUpdateContent, Edu, PresenceContent, PresenceUpdate, ReceiptContent, ReceiptData, ReceiptMap, + }, + }, + OutgoingRequest, + }, + device_id, + events::{push_rules::PushRulesEvent, 
receipt::ReceiptType, AnySyncEphemeralRoomEvent, GlobalAccountDataEventType}, + push, uint, MilliSecondsSinceUnixEpoch, OwnedServerName, OwnedUserId, ServerName, UInt, UserId, }; use tokio::{ - select, - sync::{mpsc, Mutex, Semaphore}, + select, + sync::{mpsc, Mutex, Semaphore}, }; use tracing::{debug, error, info, warn}; +use crate::{ + api::{appservice_server, server_server}, + services, + utils::calculate_hash, + Config, Error, PduEvent, Result, +}; + #[derive(Clone, Debug, PartialEq, Eq, Hash)] pub enum OutgoingKind { - Appservice(String), - Push(OwnedUserId, String), // user and pushkey - Normal(OwnedServerName), + Appservice(String), + Push(OwnedUserId, String), // user and pushkey + Normal(OwnedServerName), } impl OutgoingKind { - #[tracing::instrument(skip(self))] - pub fn get_prefix(&self) -> Vec { - let mut prefix = match self { - OutgoingKind::Appservice(server) => { - let mut p = b"+".to_vec(); - p.extend_from_slice(server.as_bytes()); - p - } - OutgoingKind::Push(user, pushkey) => { - let mut p = b"$".to_vec(); - p.extend_from_slice(user.as_bytes()); - p.push(0xff); - p.extend_from_slice(pushkey.as_bytes()); - p - } - OutgoingKind::Normal(server) => { - let mut p = Vec::new(); - p.extend_from_slice(server.as_bytes()); - p - } - }; - prefix.push(0xff); + #[tracing::instrument(skip(self))] + pub fn get_prefix(&self) -> Vec { + let mut prefix = match self { + OutgoingKind::Appservice(server) => { + let mut p = b"+".to_vec(); + p.extend_from_slice(server.as_bytes()); + p + }, + OutgoingKind::Push(user, pushkey) => { + let mut p = b"$".to_vec(); + p.extend_from_slice(user.as_bytes()); + p.push(0xFF); + p.extend_from_slice(pushkey.as_bytes()); + p + }, + OutgoingKind::Normal(server) => { + let mut p = Vec::new(); + p.extend_from_slice(server.as_bytes()); + p + }, + }; + prefix.push(0xFF); - prefix - } + prefix + } } #[derive(Clone, Debug, PartialEq, Eq, Hash)] #[allow(clippy::module_name_repetitions)] pub enum SendingEventType { - Pdu(Vec), // pduid - 
Edu(Vec), // pdu json + Pdu(Vec), // pduid + Edu(Vec), // pdu json } pub struct Service { - db: &'static dyn Data, + db: &'static dyn Data, - /// The state for a given state hash. - pub(super) maximum_requests: Arc, - pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, - receiver: Mutex)>>, + /// The state for a given state hash. + pub(super) maximum_requests: Arc, + pub sender: mpsc::UnboundedSender<(OutgoingKind, SendingEventType, Vec)>, + receiver: Mutex)>>, } enum TransactionStatus { - Running, - Failed(u32, Instant), // number of times failed, time of last failure - Retrying(u32), // number of times failed + Running, + Failed(u32, Instant), // number of times failed, time of last failure + Retrying(u32), // number of times failed } impl Service { - pub fn build(db: &'static dyn Data, config: &Config) -> Arc { - let (sender, receiver) = mpsc::unbounded_channel(); - Arc::new(Self { - db, - sender, - receiver: Mutex::new(receiver), - maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), - }) - } - - pub fn start_handler(self: &Arc) { - let self2 = Arc::clone(self); - tokio::spawn(async move { - self2.handler().await.unwrap(); - }); - } - - async fn handler(&self) -> Result<()> { - let mut receiver = self.receiver.lock().await; - - let mut futures = FuturesUnordered::new(); - - let mut current_transaction_status = HashMap::::new(); - - // Retry requests we could not finish yet - let mut initial_transactions = HashMap::>::new(); - - for (key, outgoing_kind, event) in self - .db - .active_requests() - .filter_map(std::result::Result::ok) - { - let entry = initial_transactions - .entry(outgoing_kind.clone()) - .or_default(); - - if entry.len() > 30 { - warn!( - "Dropping some current events: {:?} {:?} {:?}", - key, outgoing_kind, event - ); - self.db.delete_active_request(key)?; - continue; - } - - entry.push(event); - } - - for (outgoing_kind, events) in initial_transactions { - 
current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running); - futures.push(Self::handle_events(outgoing_kind.clone(), events)); - } - - loop { - select! { - Some(response) = futures.next() => { - match response { - Ok(outgoing_kind) => { - self.db.delete_all_active_requests_for(&outgoing_kind)?; - - // Find events that have been added since starting the last request - let new_events = self.db.queued_requests(&outgoing_kind).filter_map(std::result::Result::ok).take(30).collect::>(); - - if !new_events.is_empty() { - // Insert pdus we found - self.db.mark_as_active(&new_events)?; - - futures.push( - Self::handle_events( - outgoing_kind.clone(), - new_events.into_iter().map(|(event, _)| event).collect(), - ) - ); - } else { - current_transaction_status.remove(&outgoing_kind); - } - } - Err((outgoing_kind, _)) => { - current_transaction_status.entry(outgoing_kind).and_modify(|e| *e = match e { - TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), - TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), - TransactionStatus::Failed(_, _) => { - error!("Request that was not even running failed?!"); - return - }, - }); - } - }; - }, - Some((outgoing_kind, event, key)) = receiver.recv() => { - if let Ok(Some(events)) = self.select_events( - &outgoing_kind, - vec![(event, key)], - &mut current_transaction_status, - ) { - futures.push(Self::handle_events(outgoing_kind, events)); - } - } - } - } - } - - #[tracing::instrument(skip(self, outgoing_kind, new_events, current_transaction_status))] - fn select_events( - &self, - outgoing_kind: &OutgoingKind, - new_events: Vec<(SendingEventType, Vec)>, // Events we want to send: event and full key - current_transaction_status: &mut HashMap, - ) -> Result>> { - let mut retry = false; - let mut allow = true; - - let entry = current_transaction_status.entry(outgoing_kind.clone()); - - entry - .and_modify(|e| match e { - TransactionStatus::Running | 
TransactionStatus::Retrying(_) => { - allow = false; // already running - } - TransactionStatus::Failed(tries, time) => { - // Fail if a request has failed recently (exponential backoff) - let mut min_elapsed_duration = - Duration::from_secs(5 * 60) * (*tries) * (*tries); - if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { - min_elapsed_duration = Duration::from_secs(60 * 60 * 24); - } - - if time.elapsed() < min_elapsed_duration { - allow = false; - } else { - retry = true; - *e = TransactionStatus::Retrying(*tries); - } - } - }) - .or_insert(TransactionStatus::Running); - - if !allow { - return Ok(None); - } - - let mut events = Vec::new(); - - if retry { - // We retry the previous transaction - for (_, e) in self - .db - .active_requests_for(outgoing_kind) - .filter_map(std::result::Result::ok) - { - events.push(e); - } - } else { - self.db.mark_as_active(&new_events)?; - for (e, _) in new_events { - events.push(e); - } - - if let OutgoingKind::Normal(server_name) = outgoing_kind { - if let Ok((select_edus, last_count)) = self.select_edus(server_name) { - events.extend(select_edus.into_iter().map(SendingEventType::Edu)); - - self.db.set_latest_educount(server_name, last_count)?; - } - } - } - - Ok(Some(events)) - } - - #[tracing::instrument(skip(self, server_name))] - pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { - // u64: count of last edu - let since = self.db.get_latest_educount(server_name)?; - let mut events = Vec::new(); - let mut max_edu_count = since; - let mut device_list_changes = HashSet::new(); - - 'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { - let room_id = room_id?; - // Look for device list updates in this room - device_list_changes.extend( - services() - .users - .keys_changed(room_id.as_ref(), since, None) - .filter_map(std::result::Result::ok) - .filter(|user_id| user_id.server_name() == services().globals.server_name()), - ); - - if 
services().globals.allow_outgoing_presence() { - // Look for presence updates in this room - let mut presence_updates = Vec::new(); - - for (user_id, count, presence_event) in services() - .rooms - .edus - .presence - .presence_since(&room_id, since) - { - if count > max_edu_count { - max_edu_count = count; - } - - if user_id.server_name() != services().globals.server_name() { - continue; - } - - presence_updates.push(PresenceUpdate { - user_id, - presence: presence_event.content.presence, - currently_active: presence_event.content.currently_active.unwrap_or(false), - last_active_ago: presence_event - .content - .last_active_ago - .unwrap_or_else(|| uint!(0)), - status_msg: presence_event.content.status_msg, - }); - } - - let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); - events.push( - serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized"), - ); - } - - // Look for read receipts in this room - for r in services() - .rooms - .edus - .read_receipt - .readreceipts_since(&room_id, since) - { - let (user_id, count, read_receipt) = r?; - - if count > max_edu_count { - max_edu_count = count; - } - - if user_id.server_name() != services().globals.server_name() { - continue; - } - - let event: AnySyncEphemeralRoomEvent = - serde_json::from_str(read_receipt.json().get()) - .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; - let federation_event = match event { - AnySyncEphemeralRoomEvent::Receipt(r) => { - let mut read = BTreeMap::new(); - - let (event_id, mut receipt) = r - .content - .0 - .into_iter() - .next() - .expect("we only use one event per read receipt"); - let receipt = receipt - .remove(&ReceiptType::Read) - .expect("our read receipts always set this") - .remove(&user_id) - .expect("our read receipts always have the user here"); - - read.insert( - user_id, - ReceiptData { - data: receipt.clone(), - event_ids: vec![event_id.clone()], - }, - ); - - let receipt_map = ReceiptMap { read }; 
- - let mut receipts = BTreeMap::new(); - receipts.insert(room_id.clone(), receipt_map); - - Edu::Receipt(ReceiptContent { receipts }) - } - _ => { - Error::bad_database("Invalid event type in read_receipts"); - continue; - } - }; - - events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); - - if events.len() >= 20 { - break 'outer; - } - } - } - - for user_id in device_list_changes { - // Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 - // Because synapse resyncs, we can just insert dummy data - let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { - user_id, - device_id: device_id!("dummy").to_owned(), - device_display_name: Some("Dummy".to_owned()), - stream_id: uint!(1), - prev_id: Vec::new(), - deleted: None, - keys: None, - }); - - events.push(serde_json::to_vec(&edu).expect("json can be serialized")); - } - - Ok((events, max_edu_count)) - } - - #[tracing::instrument(skip(self, pdu_id, user, pushkey))] - pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { - let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); - let event = SendingEventType::Pdu(pdu_id.to_owned()); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self, servers, pdu_id))] - pub fn send_pdu>( - &self, - servers: I, - pdu_id: &[u8], - ) -> Result<()> { - let requests = servers - .into_iter() - .map(|server| { - ( - OutgoingKind::Normal(server), - SendingEventType::Pdu(pdu_id.to_owned()), - ) - }) - .collect::>(); - let keys = self.db.queue_requests( - &requests - .iter() - .map(|(o, e)| (o, e.clone())) - .collect::>(), - )?; - for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { - self.sender - .send((outgoing_kind.clone(), 
event, key)) - .unwrap(); - } - - Ok(()) - } - - #[tracing::instrument(skip(self, server, serialized))] - pub fn send_reliable_edu( - &self, - server: &ServerName, - serialized: Vec, - id: u64, - ) -> Result<()> { - let outgoing_kind = OutgoingKind::Normal(server.to_owned()); - let event = SendingEventType::Edu(serialized); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - #[tracing::instrument(skip(self))] - pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { - let outgoing_kind = OutgoingKind::Appservice(appservice_id); - let event = SendingEventType::Pdu(pdu_id); - let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; - self.sender - .send((outgoing_kind, event, keys.into_iter().next().unwrap())) - .unwrap(); - - Ok(()) - } - - /// Cleanup event data - /// Used for instance after we remove an appservice registration - /// - #[tracing::instrument(skip(self))] - pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { - self.db - .delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; - - Ok(()) - } - - #[tracing::instrument(skip(events, kind))] - async fn handle_events( - kind: OutgoingKind, - events: Vec, - ) -> Result { - match &kind { - OutgoingKind::Appservice(id) => { - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - pdu_jsons.push(services().rooms.timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Appservice] Event in servernameevent_data not found in db.", - ), - ) - })? - .to_room_event()); - } - SendingEventType::Edu(_) => { - // Appservices don't need EDUs (?) 
- } - } - } - - let permit = services().sending.maximum_requests.acquire().await; - - let response = match appservice_server::send_request( - services() - .appservice - .get_registration(id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Appservice] Could not load registration from db.", - ), - ) - })?, - appservice::event::push_events::v1::Request { - events: pdu_jsons, - txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, - }) - .collect::>(), - ))) - .into(), - }, - ) - .await - { - None => Ok(kind.clone()), - Some(op_resp) => op_resp - .map(|_response| kind.clone()) - .map_err(|e| (kind.clone(), e)), - }; - - drop(permit); - - response - } - OutgoingKind::Push(userid, pushkey) => { - let mut pdus = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - pdus.push( - services().rooms - .timeline - .get_pdu_from_id(pdu_id) - .map_err(|e| (kind.clone(), e))? - .ok_or_else(|| { - ( - kind.clone(), - Error::bad_database( - "[Push] Event in servernamevent_datas not found in db.", - ), - ) - })?, - ); - } - SendingEventType::Edu(_) => { - // Push gateways don't need EDUs (?) - } - } - } - - for pdu in pdus { - // Redacted events are not notification targets (we don't send push for them) - if let Some(unsigned) = &pdu.unsigned { - if let Ok(unsigned) = - serde_json::from_str::(unsigned.get()) - { - if unsigned.get("redacted_because").is_some() { - continue; - } - } - } - - let Some(pusher) = services() - .pusher - .get_pusher(userid, pushkey) - .map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? 
- else { - continue; - }; - - let rules_for_user = services() - .account_data - .get( - None, - userid, - GlobalAccountDataEventType::PushRules.to_string().into(), - ) - .unwrap_or_default() - .and_then(|event| serde_json::from_str::(event.get()).ok()) - .map_or_else( - || push::Ruleset::server_default(userid), - |ev: PushRulesEvent| ev.content.global, - ); - - let unread: UInt = services() - .rooms - .user - .notification_count(userid, &pdu.room_id) - .map_err(|e| (kind.clone(), e))? - .try_into() - .expect("notification count can't go that high"); - - let permit = services().sending.maximum_requests.acquire().await; - - let _response = services() - .pusher - .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) - .await - .map(|_response| kind.clone()) - .map_err(|e| (kind.clone(), e)); - - drop(permit); - } - Ok(OutgoingKind::Push(userid.clone(), pushkey.clone())) - } - OutgoingKind::Normal(server) => { - let mut edu_jsons = Vec::new(); - let mut pdu_jsons = Vec::new(); - - for event in &events { - match event { - SendingEventType::Pdu(pdu_id) => { - // TODO: check room version and remove event_id if needed - let raw = PduEvent::convert_to_outgoing_federation_event( - services().rooms - .timeline - .get_pdu_json_from_id(pdu_id) - .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? 
- .ok_or_else(|| { - error!("event not found: {server} {pdu_id:?}"); - ( - OutgoingKind::Normal(server.clone()), - Error::bad_database( - "[Normal] Event in servernamevent_datas not found in db.", - ), - ) - })?, - ); - pdu_jsons.push(raw); - } - SendingEventType::Edu(edu) => { - if let Ok(raw) = serde_json::from_slice(edu) { - edu_jsons.push(raw); - } - } - } - } - - let permit = services().sending.maximum_requests.acquire().await; - - let response = server_server::send_request( - server, - send_transaction_message::v1::Request { - origin: services().globals.server_name().to_owned(), - pdus: pdu_jsons, - edus: edu_jsons, - origin_server_ts: MilliSecondsSinceUnixEpoch::now(), - transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode( - calculate_hash( - &events - .iter() - .map(|e| match e { - SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, - }) - .collect::>(), - ), - )) - .into(), - }, - ) - .await - .map(|response| { - for pdu in response.pdus { - if pdu.1.is_err() { - warn!("Failed to send to {}: {:?}", server, pdu); - } - } - kind.clone() - }) - .map_err(|e| (kind, e)); - - drop(permit); - - response - } - } - } - - #[tracing::instrument(skip(self, destination, request))] - pub async fn send_federation_request( - &self, - destination: &ServerName, - request: T, - ) -> Result - where - T: OutgoingRequest + Debug, - { - if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { - info!( - "Destination {} is an IP literal, checking against IP range denylist.", - destination - ); - let ip = IPAddress::parse(destination.host()).map_err(|e| { - warn!("Failed to parse IP literal from string: {}", e); - Error::BadServerResponse("Invalid IP address") - })?; - - let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); - let mut cidr_ranges: Vec = Vec::new(); - - for cidr in cidr_ranges_s { - cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); - } - - debug!("List of pushed CIDR ranges: {:?}", 
cidr_ranges); - - for cidr in cidr_ranges { - if cidr.includes(&ip) { - return Err(Error::BadServerResponse( - "Not allowed to send requests to this IP", - )); - } - } - - info!("IP literal {} is allowed.", destination); - } - - debug!("Waiting for permit"); - let permit = self.maximum_requests.acquire().await; - debug!("Got permit"); - let response = tokio::time::timeout( - Duration::from_secs(5 * 60), - server_server::send_request(destination, request), - ) - .await - .map_err(|_| { - warn!("Timeout after 300 seconds waiting for server response of {destination}"); - Error::BadServerResponse("Timeout after 300 seconds waiting for server response") - })?; - drop(permit); - - response - } - - /// Sends a request to an appservice - /// - /// Only returns None if there is no url specified in the appservice registration file - pub async fn send_appservice_request( - &self, - registration: Registration, - request: T, - ) -> Option> - where - T: OutgoingRequest + Debug, - { - let permit = self.maximum_requests.acquire().await; - let response = appservice_server::send_request(registration, request).await; - drop(permit); - - response - } + pub fn build(db: &'static dyn Data, config: &Config) -> Arc { + let (sender, receiver) = mpsc::unbounded_channel(); + Arc::new(Self { + db, + sender, + receiver: Mutex::new(receiver), + maximum_requests: Arc::new(Semaphore::new(config.max_concurrent_requests as usize)), + }) + } + + pub fn start_handler(self: &Arc) { + let self2 = Arc::clone(self); + tokio::spawn(async move { + self2.handler().await.unwrap(); + }); + } + + async fn handler(&self) -> Result<()> { + let mut receiver = self.receiver.lock().await; + + let mut futures = FuturesUnordered::new(); + + let mut current_transaction_status = HashMap::::new(); + + // Retry requests we could not finish yet + let mut initial_transactions = HashMap::>::new(); + + for (key, outgoing_kind, event) in self.db.active_requests().filter_map(std::result::Result::ok) { + let entry = 
initial_transactions.entry(outgoing_kind.clone()).or_default(); + + if entry.len() > 30 { + warn!("Dropping some current events: {:?} {:?} {:?}", key, outgoing_kind, event); + self.db.delete_active_request(key)?; + continue; + } + + entry.push(event); + } + + for (outgoing_kind, events) in initial_transactions { + current_transaction_status.insert(outgoing_kind.clone(), TransactionStatus::Running); + futures.push(Self::handle_events(outgoing_kind.clone(), events)); + } + + loop { + select! { + Some(response) = futures.next() => { + match response { + Ok(outgoing_kind) => { + self.db.delete_all_active_requests_for(&outgoing_kind)?; + + // Find events that have been added since starting the last request + let new_events = self.db.queued_requests(&outgoing_kind).filter_map(std::result::Result::ok).take(30).collect::>(); + + if !new_events.is_empty() { + // Insert pdus we found + self.db.mark_as_active(&new_events)?; + + futures.push( + Self::handle_events( + outgoing_kind.clone(), + new_events.into_iter().map(|(event, _)| event).collect(), + ) + ); + } else { + current_transaction_status.remove(&outgoing_kind); + } + } + Err((outgoing_kind, _)) => { + current_transaction_status.entry(outgoing_kind).and_modify(|e| *e = match e { + TransactionStatus::Running => TransactionStatus::Failed(1, Instant::now()), + TransactionStatus::Retrying(n) => TransactionStatus::Failed(*n+1, Instant::now()), + TransactionStatus::Failed(_, _) => { + error!("Request that was not even running failed?!"); + return + }, + }); + } + }; + }, + Some((outgoing_kind, event, key)) = receiver.recv() => { + if let Ok(Some(events)) = self.select_events( + &outgoing_kind, + vec![(event, key)], + &mut current_transaction_status, + ) { + futures.push(Self::handle_events(outgoing_kind, events)); + } + } + } + } + } + + #[tracing::instrument(skip(self, outgoing_kind, new_events, current_transaction_status))] + fn select_events( + &self, + outgoing_kind: &OutgoingKind, + new_events: Vec<(SendingEventType, 
Vec)>, // Events we want to send: event and full key + current_transaction_status: &mut HashMap, + ) -> Result>> { + let mut retry = false; + let mut allow = true; + + let entry = current_transaction_status.entry(outgoing_kind.clone()); + + entry + .and_modify(|e| match e { + TransactionStatus::Running | TransactionStatus::Retrying(_) => { + allow = false; // already running + }, + TransactionStatus::Failed(tries, time) => { + // Fail if a request has failed recently (exponential backoff) + let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries); + if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) { + min_elapsed_duration = Duration::from_secs(60 * 60 * 24); + } + + if time.elapsed() < min_elapsed_duration { + allow = false; + } else { + retry = true; + *e = TransactionStatus::Retrying(*tries); + } + }, + }) + .or_insert(TransactionStatus::Running); + + if !allow { + return Ok(None); + } + + let mut events = Vec::new(); + + if retry { + // We retry the previous transaction + for (_, e) in self.db.active_requests_for(outgoing_kind).filter_map(std::result::Result::ok) { + events.push(e); + } + } else { + self.db.mark_as_active(&new_events)?; + for (e, _) in new_events { + events.push(e); + } + + if let OutgoingKind::Normal(server_name) = outgoing_kind { + if let Ok((select_edus, last_count)) = self.select_edus(server_name) { + events.extend(select_edus.into_iter().map(SendingEventType::Edu)); + + self.db.set_latest_educount(server_name, last_count)?; + } + } + } + + Ok(Some(events)) + } + + #[tracing::instrument(skip(self, server_name))] + pub fn select_edus(&self, server_name: &ServerName) -> Result<(Vec>, u64)> { + // u64: count of last edu + let since = self.db.get_latest_educount(server_name)?; + let mut events = Vec::new(); + let mut max_edu_count = since; + let mut device_list_changes = HashSet::new(); + + 'outer: for room_id in services().rooms.state_cache.server_rooms(server_name) { + let room_id = room_id?; + // Look for 
device list updates in this room + device_list_changes.extend( + services() + .users + .keys_changed(room_id.as_ref(), since, None) + .filter_map(std::result::Result::ok) + .filter(|user_id| user_id.server_name() == services().globals.server_name()), + ); + + if services().globals.allow_outgoing_presence() { + // Look for presence updates in this room + let mut presence_updates = Vec::new(); + + for (user_id, count, presence_event) in services().rooms.edus.presence.presence_since(&room_id, since) { + if count > max_edu_count { + max_edu_count = count; + } + + if user_id.server_name() != services().globals.server_name() { + continue; + } + + presence_updates.push(PresenceUpdate { + user_id, + presence: presence_event.content.presence, + currently_active: presence_event.content.currently_active.unwrap_or(false), + last_active_ago: presence_event.content.last_active_ago.unwrap_or_else(|| uint!(0)), + status_msg: presence_event.content.status_msg, + }); + } + + let presence_content = Edu::Presence(PresenceContent::new(presence_updates)); + events.push(serde_json::to_vec(&presence_content).expect("PresenceEvent can be serialized")); + } + + // Look for read receipts in this room + for r in services().rooms.edus.read_receipt.readreceipts_since(&room_id, since) { + let (user_id, count, read_receipt) = r?; + + if count > max_edu_count { + max_edu_count = count; + } + + if user_id.server_name() != services().globals.server_name() { + continue; + } + + let event: AnySyncEphemeralRoomEvent = serde_json::from_str(read_receipt.json().get()) + .map_err(|_| Error::bad_database("Invalid edu event in read_receipts."))?; + let federation_event = match event { + AnySyncEphemeralRoomEvent::Receipt(r) => { + let mut read = BTreeMap::new(); + + let (event_id, mut receipt) = + r.content.0.into_iter().next().expect("we only use one event per read receipt"); + let receipt = receipt + .remove(&ReceiptType::Read) + .expect("our read receipts always set this") + .remove(&user_id) + 
.expect("our read receipts always have the user here"); + + read.insert( + user_id, + ReceiptData { + data: receipt.clone(), + event_ids: vec![event_id.clone()], + }, + ); + + let receipt_map = ReceiptMap { + read, + }; + + let mut receipts = BTreeMap::new(); + receipts.insert(room_id.clone(), receipt_map); + + Edu::Receipt(ReceiptContent { + receipts, + }) + }, + _ => { + Error::bad_database("Invalid event type in read_receipts"); + continue; + }, + }; + + events.push(serde_json::to_vec(&federation_event).expect("json can be serialized")); + + if events.len() >= 20 { + break 'outer; + } + } + } + + for user_id in device_list_changes { + // Empty prev id forces synapse to resync: https://github.com/matrix-org/synapse/blob/98aec1cc9da2bd6b8e34ffb282c85abf9b8b42ca/synapse/handlers/device.py#L767 + // Because synapse resyncs, we can just insert dummy data + let edu = Edu::DeviceListUpdate(DeviceListUpdateContent { + user_id, + device_id: device_id!("dummy").to_owned(), + device_display_name: Some("Dummy".to_owned()), + stream_id: uint!(1), + prev_id: Vec::new(), + deleted: None, + keys: None, + }); + + events.push(serde_json::to_vec(&edu).expect("json can be serialized")); + } + + Ok((events, max_edu_count)) + } + + #[tracing::instrument(skip(self, pdu_id, user, pushkey))] + pub fn send_push_pdu(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> { + let outgoing_kind = OutgoingKind::Push(user.to_owned(), pushkey); + let event = SendingEventType::Pdu(pdu_id.to_owned()); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self, servers, pdu_id))] + pub fn send_pdu>(&self, servers: I, pdu_id: &[u8]) -> Result<()> { + let requests = servers + .into_iter() + .map(|server| (OutgoingKind::Normal(server), SendingEventType::Pdu(pdu_id.to_owned()))) + .collect::>(); + let keys = 
self.db.queue_requests(&requests.iter().map(|(o, e)| (o, e.clone())).collect::>())?; + for ((outgoing_kind, event), key) in requests.into_iter().zip(keys) { + self.sender.send((outgoing_kind.clone(), event, key)).unwrap(); + } + + Ok(()) + } + + #[tracing::instrument(skip(self, server, serialized))] + pub fn send_reliable_edu(&self, server: &ServerName, serialized: Vec, id: u64) -> Result<()> { + let outgoing_kind = OutgoingKind::Normal(server.to_owned()); + let event = SendingEventType::Edu(serialized); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); + + Ok(()) + } + + #[tracing::instrument(skip(self))] + pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec) -> Result<()> { + let outgoing_kind = OutgoingKind::Appservice(appservice_id); + let event = SendingEventType::Pdu(pdu_id); + let keys = self.db.queue_requests(&[(&outgoing_kind, event.clone())])?; + self.sender.send((outgoing_kind, event, keys.into_iter().next().unwrap())).unwrap(); + + Ok(()) + } + + /// Cleanup event data + /// Used for instance after we remove an appservice registration + #[tracing::instrument(skip(self))] + pub fn cleanup_events(&self, appservice_id: String) -> Result<()> { + self.db.delete_all_requests_for(&OutgoingKind::Appservice(appservice_id))?; + + Ok(()) + } + + #[tracing::instrument(skip(events, kind))] + async fn handle_events( + kind: OutgoingKind, events: Vec, + ) -> Result { + match &kind { + OutgoingKind::Appservice(id) => { + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + pdu_jsons.push( + services() + .rooms + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Appservice] Event in servernameevent_data not found in db.", + ), + ) + })? 
+ .to_room_event(), + ); + }, + SendingEventType::Edu(_) => { + // Appservices don't need EDUs (?) + }, + } + } + + let permit = services().sending.maximum_requests.acquire().await; + + let response = match appservice_server::send_request( + services().appservice.get_registration(id).map_err(|e| (kind.clone(), e))?.ok_or_else(|| { + ( + kind.clone(), + Error::bad_database("[Appservice] Could not load registration from db."), + ) + })?, + appservice::event::push_events::v1::Request { + events: pdu_jsons, + txn_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + }) + .collect::>(), + ))) + .into(), + }, + ) + .await + { + None => Ok(kind.clone()), + Some(op_resp) => op_resp.map(|_response| kind.clone()).map_err(|e| (kind.clone(), e)), + }; + + drop(permit); + + response + }, + OutgoingKind::Push(userid, pushkey) => { + let mut pdus = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + pdus.push( + services() + .rooms + .timeline + .get_pdu_from_id(pdu_id) + .map_err(|e| (kind.clone(), e))? + .ok_or_else(|| { + ( + kind.clone(), + Error::bad_database( + "[Push] Event in servernamevent_datas not found in db.", + ), + ) + })?, + ); + }, + SendingEventType::Edu(_) => { + // Push gateways don't need EDUs (?) + }, + } + } + + for pdu in pdus { + // Redacted events are not notification targets (we don't send push for them) + if let Some(unsigned) = &pdu.unsigned { + if let Ok(unsigned) = serde_json::from_str::(unsigned.get()) { + if unsigned.get("redacted_because").is_some() { + continue; + } + } + } + + let Some(pusher) = services() + .pusher + .get_pusher(userid, pushkey) + .map_err(|e| (OutgoingKind::Push(userid.clone(), pushkey.clone()), e))? 
+ else { + continue; + }; + + let rules_for_user = services() + .account_data + .get(None, userid, GlobalAccountDataEventType::PushRules.to_string().into()) + .unwrap_or_default() + .and_then(|event| serde_json::from_str::(event.get()).ok()) + .map_or_else(|| push::Ruleset::server_default(userid), |ev: PushRulesEvent| ev.content.global); + + let unread: UInt = services() + .rooms + .user + .notification_count(userid, &pdu.room_id) + .map_err(|e| (kind.clone(), e))? + .try_into() + .expect("notification count can't go that high"); + + let permit = services().sending.maximum_requests.acquire().await; + + let _response = services() + .pusher + .send_push_notice(userid, unread, &pusher, rules_for_user, &pdu) + .await + .map(|_response| kind.clone()) + .map_err(|e| (kind.clone(), e)); + + drop(permit); + } + Ok(OutgoingKind::Push(userid.clone(), pushkey.clone())) + }, + OutgoingKind::Normal(server) => { + let mut edu_jsons = Vec::new(); + let mut pdu_jsons = Vec::new(); + + for event in &events { + match event { + SendingEventType::Pdu(pdu_id) => { + // TODO: check room version and remove event_id if needed + let raw = PduEvent::convert_to_outgoing_federation_event( + services() + .rooms + .timeline + .get_pdu_json_from_id(pdu_id) + .map_err(|e| (OutgoingKind::Normal(server.clone()), e))? 
+ .ok_or_else(|| { + error!("event not found: {server} {pdu_id:?}"); + ( + OutgoingKind::Normal(server.clone()), + Error::bad_database( + "[Normal] Event in servernamevent_datas not found in db.", + ), + ) + })?, + ); + pdu_jsons.push(raw); + }, + SendingEventType::Edu(edu) => { + if let Ok(raw) = serde_json::from_slice(edu) { + edu_jsons.push(raw); + } + }, + } + } + + let permit = services().sending.maximum_requests.acquire().await; + + let response = server_server::send_request( + server, + send_transaction_message::v1::Request { + origin: services().globals.server_name().to_owned(), + pdus: pdu_jsons, + edus: edu_jsons, + origin_server_ts: MilliSecondsSinceUnixEpoch::now(), + transaction_id: (&*general_purpose::URL_SAFE_NO_PAD.encode(calculate_hash( + &events + .iter() + .map(|e| match e { + SendingEventType::Edu(b) | SendingEventType::Pdu(b) => &**b, + }) + .collect::>(), + ))) + .into(), + }, + ) + .await + .map(|response| { + for pdu in response.pdus { + if pdu.1.is_err() { + warn!("Failed to send to {}: {:?}", server, pdu); + } + } + kind.clone() + }) + .map_err(|e| (kind, e)); + + drop(permit); + + response + }, + } + } + + #[tracing::instrument(skip(self, destination, request))] + pub async fn send_federation_request(&self, destination: &ServerName, request: T) -> Result + where + T: OutgoingRequest + Debug, + { + if destination.is_ip_literal() || IPAddress::is_valid(destination.host()) { + info!( + "Destination {} is an IP literal, checking against IP range denylist.", + destination + ); + let ip = IPAddress::parse(destination.host()).map_err(|e| { + warn!("Failed to parse IP literal from string: {}", e); + Error::BadServerResponse("Invalid IP address") + })?; + + let cidr_ranges_s = services().globals.ip_range_denylist().to_vec(); + let mut cidr_ranges: Vec = Vec::new(); + + for cidr in cidr_ranges_s { + cidr_ranges.push(IPAddress::parse(cidr).expect("we checked this at startup")); + } + + debug!("List of pushed CIDR ranges: {:?}", cidr_ranges); + + for 
cidr in cidr_ranges { + if cidr.includes(&ip) { + return Err(Error::BadServerResponse("Not allowed to send requests to this IP")); + } + } + + info!("IP literal {} is allowed.", destination); + } + + debug!("Waiting for permit"); + let permit = self.maximum_requests.acquire().await; + debug!("Got permit"); + let response = + tokio::time::timeout(Duration::from_secs(5 * 60), server_server::send_request(destination, request)) + .await + .map_err(|_| { + warn!("Timeout after 300 seconds waiting for server response of {destination}"); + Error::BadServerResponse("Timeout after 300 seconds waiting for server response") + })?; + drop(permit); + + response + } + + /// Sends a request to an appservice + /// + /// Only returns None if there is no url specified in the appservice + /// registration file + pub async fn send_appservice_request( + &self, registration: Registration, request: T, + ) -> Option> + where + T: OutgoingRequest + Debug, + { + let permit = self.maximum_requests.acquire().await; + let response = appservice_server::send_request(registration, request).await; + drop(permit); + + response + } } diff --git a/src/service/transaction_ids/data.rs b/src/service/transaction_ids/data.rs index 74855318..2aed1981 100644 --- a/src/service/transaction_ids/data.rs +++ b/src/service/transaction_ids/data.rs @@ -1,19 +1,13 @@ -use crate::Result; use ruma::{DeviceId, TransactionId, UserId}; -pub trait Data: Send + Sync { - fn add_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - data: &[u8], - ) -> Result<()>; +use crate::Result; - fn existing_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - ) -> Result>>; +pub trait Data: Send + Sync { + fn add_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], + ) -> Result<()>; + + fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, + ) -> Result>>; } 
diff --git a/src/service/transaction_ids/mod.rs b/src/service/transaction_ids/mod.rs index 2fa3b02e..bc55861a 100644 --- a/src/service/transaction_ids/mod.rs +++ b/src/service/transaction_ids/mod.rs @@ -1,31 +1,24 @@ mod data; pub use data::Data; - -use crate::Result; use ruma::{DeviceId, TransactionId, UserId}; +use crate::Result; + pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - pub fn add_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - data: &[u8], - ) -> Result<()> { - self.db.add_txnid(user_id, device_id, txn_id, data) - } + pub fn add_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8], + ) -> Result<()> { + self.db.add_txnid(user_id, device_id, txn_id, data) + } - pub fn existing_txnid( - &self, - user_id: &UserId, - device_id: Option<&DeviceId>, - txn_id: &TransactionId, - ) -> Result>> { - self.db.existing_txnid(user_id, device_id, txn_id) - } + pub fn existing_txnid( + &self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, + ) -> Result>> { + self.db.existing_txnid(user_id, device_id, txn_id) + } } diff --git a/src/service/uiaa/data.rs b/src/service/uiaa/data.rs index c64deb90..3a157068 100644 --- a/src/service/uiaa/data.rs +++ b/src/service/uiaa/data.rs @@ -1,34 +1,17 @@ -use crate::Result; use ruma::{api::client::uiaa::UiaaInfo, CanonicalJsonValue, DeviceId, UserId}; +use crate::Result; + pub trait Data: Send + Sync { - fn set_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - request: &CanonicalJsonValue, - ) -> Result<()>; + fn set_uiaa_request( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue, + ) -> Result<()>; - fn get_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Option; + fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> 
Option; - fn update_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - uiaainfo: Option<&UiaaInfo>, - ) -> Result<()>; + fn update_uiaa_session( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>, + ) -> Result<()>; - fn get_uiaa_session( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Result; + fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result; } diff --git a/src/service/uiaa/mod.rs b/src/service/uiaa/mod.rs index 10ac3664..a2296d55 100644 --- a/src/service/uiaa/mod.rs +++ b/src/service/uiaa/mod.rs @@ -2,159 +2,142 @@ mod data; use argon2::{PasswordHash, PasswordVerifier}; pub use data::Data; - use ruma::{ - api::client::{ - error::ErrorKind, - uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, - }, - CanonicalJsonValue, DeviceId, UserId, + api::client::{ + error::ErrorKind, + uiaa::{AuthData, AuthType, Password, UiaaInfo, UserIdentifier}, + }, + CanonicalJsonValue, DeviceId, UserId, }; use tracing::error; use crate::{api::client_server::SESSION_ID_LENGTH, services, utils, Error, Result}; pub struct Service { - pub db: &'static dyn Data, + pub db: &'static dyn Data, } impl Service { - /// Creates a new Uiaa session. Make sure the session token is unique. - pub fn create( - &self, - user_id: &UserId, - device_id: &DeviceId, - uiaainfo: &UiaaInfo, - json_body: &CanonicalJsonValue, - ) -> Result<()> { - self.db.set_uiaa_request( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), // TODO: better session error handling (why is it optional in ruma?) - json_body, - )?; - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session should be set"), - Some(uiaainfo), - ) - } + /// Creates a new Uiaa session. Make sure the session token is unique. 
+ pub fn create( + &self, user_id: &UserId, device_id: &DeviceId, uiaainfo: &UiaaInfo, json_body: &CanonicalJsonValue, + ) -> Result<()> { + self.db.set_uiaa_request( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), /* TODO: better session error handling (why + * is it optional in ruma?) */ + json_body, + )?; + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session should be set"), + Some(uiaainfo), + ) + } - pub fn try_auth( - &self, - user_id: &UserId, - device_id: &DeviceId, - auth: &AuthData, - uiaainfo: &UiaaInfo, - ) -> Result<(bool, UiaaInfo)> { - let mut uiaainfo = auth.session().map_or_else( - || Ok(uiaainfo.clone()), - |session| self.db.get_uiaa_session(user_id, device_id, session), - )?; + pub fn try_auth( + &self, user_id: &UserId, device_id: &DeviceId, auth: &AuthData, uiaainfo: &UiaaInfo, + ) -> Result<(bool, UiaaInfo)> { + let mut uiaainfo = auth.session().map_or_else( + || Ok(uiaainfo.clone()), + |session| self.db.get_uiaa_session(user_id, device_id, session), + )?; - if uiaainfo.session.is_none() { - uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); - } + if uiaainfo.session.is_none() { + uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH)); + } - match auth { - // Find out what the user completed - AuthData::Password(Password { - identifier, - password, - .. - }) => { - let UserIdentifier::UserIdOrLocalpart(username) = identifier else { - return Err(Error::BadRequest( - ErrorKind::Unrecognized, - "Identifier type not recognized.", - )); - }; + match auth { + // Find out what the user completed + AuthData::Password(Password { + identifier, + password, + .. 
+ }) => { + let UserIdentifier::UserIdOrLocalpart(username) = identifier else { + return Err(Error::BadRequest(ErrorKind::Unrecognized, "Identifier type not recognized.")); + }; - let user_id = UserId::parse_with_server_name( - username.clone(), - services().globals.server_name(), - ) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; + let user_id = UserId::parse_with_server_name(username.clone(), services().globals.server_name()) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "User ID is invalid."))?; - // Check if password is correct - if let Some(hash) = services().users.password_hash(&user_id)? { - let hash_matches = services() - .globals - .argon - .verify_password( - password.as_bytes(), - &PasswordHash::new(&hash).expect("valid hash in database"), - ) - .is_ok(); + // Check if password is correct + if let Some(hash) = services().users.password_hash(&user_id)? { + let hash_matches = services() + .globals + .argon + .verify_password( + password.as_bytes(), + &PasswordHash::new(&hash).expect("valid hash in database"), + ) + .is_ok(); - if !hash_matches { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid username or password.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } + if !hash_matches { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid username or password.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + } - // Password was correct! 
Let's add it to `completed` - uiaainfo.completed.push(AuthType::Password); - } - AuthData::RegistrationToken(t) => { - if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { - uiaainfo.completed.push(AuthType::RegistrationToken); - } else { - uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { - kind: ErrorKind::Forbidden, - message: "Invalid registration token.".to_owned(), - }); - return Ok((false, uiaainfo)); - } - } - AuthData::Dummy(_) => { - uiaainfo.completed.push(AuthType::Dummy); - } - k => error!("type not supported: {:?}", k), - } + // Password was correct! Let's add it to `completed` + uiaainfo.completed.push(AuthType::Password); + }, + AuthData::RegistrationToken(t) => { + if Some(t.token.trim()) == services().globals.config.registration_token.as_deref() { + uiaainfo.completed.push(AuthType::RegistrationToken); + } else { + uiaainfo.auth_error = Some(ruma::api::client::error::StandardErrorBody { + kind: ErrorKind::Forbidden, + message: "Invalid registration token.".to_owned(), + }); + return Ok((false, uiaainfo)); + } + }, + AuthData::Dummy(_) => { + uiaainfo.completed.push(AuthType::Dummy); + }, + k => error!("type not supported: {:?}", k), + } - // Check if a flow now succeeds - let mut completed = false; - 'flows: for flow in &mut uiaainfo.flows { - for stage in &flow.stages { - if !uiaainfo.completed.contains(stage) { - continue 'flows; - } - } - // We didn't break, so this flow succeeded! - completed = true; - } + // Check if a flow now succeeds + let mut completed = false; + 'flows: for flow in &mut uiaainfo.flows { + for stage in &flow.stages { + if !uiaainfo.completed.contains(stage) { + continue 'flows; + } + } + // We didn't break, so this flow succeeded! 
+ completed = true; + } - if !completed { - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - Some(&uiaainfo), - )?; - return Ok((false, uiaainfo)); - } + if !completed { + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + Some(&uiaainfo), + )?; + return Ok((false, uiaainfo)); + } - // UIAA was successful! Remove this session and return true - self.db.update_uiaa_session( - user_id, - device_id, - uiaainfo.session.as_ref().expect("session is always set"), - None, - )?; - Ok((true, uiaainfo)) - } + // UIAA was successful! Remove this session and return true + self.db.update_uiaa_session( + user_id, + device_id, + uiaainfo.session.as_ref().expect("session is always set"), + None, + )?; + Ok((true, uiaainfo)) + } - pub fn get_uiaa_request( - &self, - user_id: &UserId, - device_id: &DeviceId, - session: &str, - ) -> Option { - self.db.get_uiaa_request(user_id, device_id, session) - } + pub fn get_uiaa_request( + &self, user_id: &UserId, device_id: &DeviceId, session: &str, + ) -> Option { + self.db.get_uiaa_request(user_id, device_id, session) + } } diff --git a/src/service/users/data.rs b/src/service/users/data.rs index ddf941e3..04074e85 100644 --- a/src/service/users/data.rs +++ b/src/service/users/data.rs @@ -1,214 +1,146 @@ -use crate::Result; -use ruma::{ - api::client::{device::Device, filter::FilterDefinition}, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, - OwnedUserId, UInt, UserId, -}; use std::collections::BTreeMap; +use ruma::{ + api::client::{device::Device, filter::FilterDefinition}, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::AnyToDeviceEvent, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, 
UInt, UserId, +}; + +use crate::Result; + pub trait Data: Send + Sync { - /// Check if a user has an account on this homeserver. - fn exists(&self, user_id: &UserId) -> Result; + /// Check if a user has an account on this homeserver. + fn exists(&self, user_id: &UserId) -> Result; - /// Check if account is deactivated - fn is_deactivated(&self, user_id: &UserId) -> Result; + /// Check if account is deactivated + fn is_deactivated(&self, user_id: &UserId) -> Result; - /// Returns the number of users registered on this server. - fn count(&self) -> Result; + /// Returns the number of users registered on this server. + fn count(&self) -> Result; - /// Find out which user an access token belongs to. - fn find_from_token(&self, token: &str) -> Result>; + /// Find out which user an access token belongs to. + fn find_from_token(&self, token: &str) -> Result>; - /// Returns an iterator over all users on this homeserver. - fn iter<'a>(&'a self) -> Box> + 'a>; + /// Returns an iterator over all users on this homeserver. + fn iter<'a>(&'a self) -> Box> + 'a>; - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is greater then zero. - fn list_local_users(&self) -> Result>; + /// Returns a list of local users as list of usernames. + /// + /// A user account is considered `local` if the length of it's password is + /// greater then zero. + fn list_local_users(&self) -> Result>; - /// Returns the password hash for the given user. - fn password_hash(&self, user_id: &UserId) -> Result>; + /// Returns the password hash for the given user. 
+ fn password_hash(&self, user_id: &UserId) -> Result>; - /// Hash and set the user's password to the Argon2 hash - fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; + /// Hash and set the user's password to the Argon2 hash + fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()>; - /// Returns the displayname of a user on this homeserver. - fn displayname(&self, user_id: &UserId) -> Result>; + /// Returns the displayname of a user on this homeserver. + fn displayname(&self, user_id: &UserId) -> Result>; - /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; + /// Sets a new displayname or removes it if displayname is None. You still + /// need to nofify all rooms of this change. + fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()>; - /// Get the avatar_url of a user. - fn avatar_url(&self, user_id: &UserId) -> Result>; + /// Get the avatar_url of a user. + fn avatar_url(&self, user_id: &UserId) -> Result>; - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()>; + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()>; - /// Get the blurhash of a user. - fn blurhash(&self, user_id: &UserId) -> Result>; + /// Get the blurhash of a user. + fn blurhash(&self, user_id: &UserId) -> Result>; - /// Sets a new avatar_url or removes it if avatar_url is None. - fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; + /// Sets a new avatar_url or removes it if avatar_url is None. + fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()>; - /// Adds a new device to a user. 
- fn create_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - token: &str, - initial_device_display_name: Option, - ) -> Result<()>; + /// Adds a new device to a user. + fn create_device( + &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, + ) -> Result<()>; - /// Removes a device from a user. - fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; + /// Removes a device from a user. + fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()>; - /// Returns an iterator over all device ids of this user. - fn all_device_ids<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a>; + /// Returns an iterator over all device ids of this user. + fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box> + 'a>; - /// Replaces the access token of one device. - fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; + /// Replaces the access token of one device. + fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()>; - fn add_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()>; + fn add_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + ) -> Result<()>; - fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; + fn last_one_time_keys_update(&self, user_id: &UserId) -> Result; - fn take_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>>; + fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result)>>; - fn count_one_time_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>; + fn count_one_time_keys(&self, user_id: &UserId, device_id: &DeviceId) + -> Result>; - fn add_device_keys( - 
&self, - user_id: &UserId, - device_id: &DeviceId, - device_keys: &Raw, - ) -> Result<()>; + fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()>; - fn add_cross_signing_keys( - &self, - user_id: &UserId, - master_key: &Raw, - self_signing_key: &Option>, - user_signing_key: &Option>, - notify: bool, - ) -> Result<()>; + fn add_cross_signing_keys( + &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, + user_signing_key: &Option>, notify: bool, + ) -> Result<()>; - fn sign_key( - &self, - target_id: &UserId, - key_id: &str, - signature: (String, String), - sender_id: &UserId, - ) -> Result<()>; + fn sign_key(&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId) + -> Result<()>; - fn keys_changed<'a>( - &'a self, - user_or_room_id: &str, - from: u64, - to: Option, - ) -> Box> + 'a>; + fn keys_changed<'a>( + &'a self, user_or_room_id: &str, from: u64, to: Option, + ) -> Box> + 'a>; - fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; + fn mark_device_key_update(&self, user_id: &UserId) -> Result<()>; - fn get_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>>; + fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>>; - fn parse_master_key( - &self, - user_id: &UserId, - master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)>; + fn parse_master_key( + &self, user_id: &UserId, master_key: &Raw, + ) -> Result<(Vec, CrossSigningKey)>; - fn get_key( - &self, - key: &[u8], - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; + fn get_key( + &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>>; - fn get_master_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; + fn get_master_key( + &self, 
sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>>; - fn get_self_signing_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>>; + fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>>; - fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; + fn get_user_signing_key(&self, user_id: &UserId) -> Result>>; - fn add_to_device_event( - &self, - sender: &UserId, - target_user_id: &UserId, - target_device_id: &DeviceId, - event_type: &str, - content: serde_json::Value, - ) -> Result<()>; + fn add_to_device_event( + &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, + content: serde_json::Value, + ) -> Result<()>; - fn get_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>>; + fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>>; - fn remove_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - until: u64, - ) -> Result<()>; + fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()>; - fn update_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - device: &Device, - ) -> Result<()>; + fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()>; - /// Get device metadata. - fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) - -> Result>; + /// Get device metadata. 
+ fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result>; - fn get_devicelist_version(&self, user_id: &UserId) -> Result>; + fn get_devicelist_version(&self, user_id: &UserId) -> Result>; - fn all_devices_metadata<'a>( - &'a self, - user_id: &UserId, - ) -> Box> + 'a>; + fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box> + 'a>; - /// Creates a new sync filter. Returns the filter id. - fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result; + /// Creates a new sync filter. Returns the filter id. + fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result; - fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; + fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result>; } diff --git a/src/service/users/mod.rs b/src/service/users/mod.rs index d2d0cb64..1b1622ee 100644 --- a/src/service/users/mod.rs +++ b/src/service/users/mod.rs @@ -1,630 +1,428 @@ mod data; use std::{ - collections::{BTreeMap, BTreeSet}, - mem, - sync::{Arc, Mutex}, + collections::{BTreeMap, BTreeSet}, + mem, + sync::{Arc, Mutex}, }; pub use data::Data; use ruma::{ - api::client::{ - device::Device, - error::ErrorKind, - filter::FilterDefinition, - sync::sync_events::{ - self, - v4::{ExtensionsConfig, SyncRequestList}, - }, - }, - encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, - events::AnyToDeviceEvent, - serde::Raw, - DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, - OwnedRoomId, OwnedUserId, RoomAliasId, UInt, UserId, + api::client::{ + device::Device, + error::ErrorKind, + filter::FilterDefinition, + sync::sync_events::{ + self, + v4::{ExtensionsConfig, SyncRequestList}, + }, + }, + encryption::{CrossSigningKey, DeviceKeys, OneTimeKey}, + events::AnyToDeviceEvent, + serde::Raw, + DeviceId, DeviceKeyAlgorithm, DeviceKeyId, OwnedDeviceId, OwnedDeviceKeyId, OwnedMxcUri, OwnedRoomId, OwnedUserId, + RoomAliasId, UInt, UserId, }; use 
crate::{services, Error, Result}; pub struct SlidingSyncCache { - lists: BTreeMap, - subscriptions: BTreeMap, - known_rooms: BTreeMap>, // For every room, the roomsince number - extensions: ExtensionsConfig, + lists: BTreeMap, + subscriptions: BTreeMap, + known_rooms: BTreeMap>, // For every room, the roomsince number + extensions: ExtensionsConfig, } -type DbConnections = - Mutex>>>; +type DbConnections = Mutex>>>; pub struct Service { - pub db: &'static dyn Data, - pub connections: DbConnections, + pub db: &'static dyn Data, + pub connections: DbConnections, } impl Service { - /// Check if a user has an account on this homeserver. - pub fn exists(&self, user_id: &UserId) -> Result { - self.db.exists(user_id) - } + /// Check if a user has an account on this homeserver. + pub fn exists(&self, user_id: &UserId) -> Result { self.db.exists(user_id) } - pub fn forget_sync_request_connection( - &self, - user_id: OwnedUserId, - device_id: OwnedDeviceId, - conn_id: String, - ) { - self.connections - .lock() - .unwrap() - .remove(&(user_id, device_id, conn_id)); - } + pub fn forget_sync_request_connection(&self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String) { + self.connections.lock().unwrap().remove(&(user_id, device_id, conn_id)); + } - pub fn update_sync_request_with_cache( - &self, - user_id: OwnedUserId, - device_id: OwnedDeviceId, - request: &mut sync_events::v4::Request, - ) -> BTreeMap> { - let Some(conn_id) = request.conn_id.clone() else { - return BTreeMap::new(); - }; + pub fn update_sync_request_with_cache( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, request: &mut sync_events::v4::Request, + ) -> BTreeMap> { + let Some(conn_id) = request.conn_id.clone() else { + return BTreeMap::new(); + }; - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: 
BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + })); + let cached = &mut cached.lock().unwrap(); + drop(cache); - for (list_id, list) in &mut request.lists { - if let Some(cached_list) = cached.lists.get(list_id) { - if list.sort.is_empty() { - list.sort = cached_list.sort.clone(); - }; - if list.room_details.required_state.is_empty() { - list.room_details.required_state = - cached_list.room_details.required_state.clone(); - }; - list.room_details.timeline_limit = list - .room_details - .timeline_limit - .or(cached_list.room_details.timeline_limit); - list.include_old_rooms = list - .include_old_rooms - .clone() - .or_else(|| cached_list.include_old_rooms.clone()); - match (&mut list.filters, cached_list.filters.clone()) { - (Some(list_filters), Some(cached_filters)) => { - list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); - if list_filters.spaces.is_empty() { - list_filters.spaces = cached_filters.spaces; - } - list_filters.is_encrypted = - list_filters.is_encrypted.or(cached_filters.is_encrypted); - list_filters.is_invite = - list_filters.is_invite.or(cached_filters.is_invite); - if list_filters.room_types.is_empty() { - list_filters.room_types = cached_filters.room_types; - } - if list_filters.not_room_types.is_empty() { - list_filters.not_room_types = cached_filters.not_room_types; - } - list_filters.room_name_like = list_filters - .room_name_like - .clone() - .or(cached_filters.room_name_like); - if list_filters.tags.is_empty() { - list_filters.tags = cached_filters.tags; - } - if 
list_filters.not_tags.is_empty() { - list_filters.not_tags = cached_filters.not_tags; - } - } - (_, Some(cached_filters)) => list.filters = Some(cached_filters), - (Some(list_filters), _) => list.filters = Some(list_filters.clone()), - (_, _) => {} - } - if list.bump_event_types.is_empty() { - list.bump_event_types = cached_list.bump_event_types.clone(); - }; - } - cached.lists.insert(list_id.clone(), list.clone()); - } + for (list_id, list) in &mut request.lists { + if let Some(cached_list) = cached.lists.get(list_id) { + if list.sort.is_empty() { + list.sort = cached_list.sort.clone(); + }; + if list.room_details.required_state.is_empty() { + list.room_details.required_state = cached_list.room_details.required_state.clone(); + }; + list.room_details.timeline_limit = + list.room_details.timeline_limit.or(cached_list.room_details.timeline_limit); + list.include_old_rooms = + list.include_old_rooms.clone().or_else(|| cached_list.include_old_rooms.clone()); + match (&mut list.filters, cached_list.filters.clone()) { + (Some(list_filters), Some(cached_filters)) => { + list_filters.is_dm = list_filters.is_dm.or(cached_filters.is_dm); + if list_filters.spaces.is_empty() { + list_filters.spaces = cached_filters.spaces; + } + list_filters.is_encrypted = list_filters.is_encrypted.or(cached_filters.is_encrypted); + list_filters.is_invite = list_filters.is_invite.or(cached_filters.is_invite); + if list_filters.room_types.is_empty() { + list_filters.room_types = cached_filters.room_types; + } + if list_filters.not_room_types.is_empty() { + list_filters.not_room_types = cached_filters.not_room_types; + } + list_filters.room_name_like = + list_filters.room_name_like.clone().or(cached_filters.room_name_like); + if list_filters.tags.is_empty() { + list_filters.tags = cached_filters.tags; + } + if list_filters.not_tags.is_empty() { + list_filters.not_tags = cached_filters.not_tags; + } + }, + (_, Some(cached_filters)) => list.filters = Some(cached_filters), + (Some(list_filters), 
_) => list.filters = Some(list_filters.clone()), + (..) => {}, + } + if list.bump_event_types.is_empty() { + list.bump_event_types = cached_list.bump_event_types.clone(); + }; + } + cached.lists.insert(list_id.clone(), list.clone()); + } - cached - .subscriptions - .extend(request.room_subscriptions.clone()); - request - .room_subscriptions - .extend(cached.subscriptions.clone()); + cached.subscriptions.extend(request.room_subscriptions.clone()); + request.room_subscriptions.extend(cached.subscriptions.clone()); - request.extensions.e2ee.enabled = request - .extensions - .e2ee - .enabled - .or(cached.extensions.e2ee.enabled); + request.extensions.e2ee.enabled = request.extensions.e2ee.enabled.or(cached.extensions.e2ee.enabled); - request.extensions.to_device.enabled = request - .extensions - .to_device - .enabled - .or(cached.extensions.to_device.enabled); + request.extensions.to_device.enabled = + request.extensions.to_device.enabled.or(cached.extensions.to_device.enabled); - request.extensions.account_data.enabled = request - .extensions - .account_data - .enabled - .or(cached.extensions.account_data.enabled); - request.extensions.account_data.lists = request - .extensions - .account_data - .lists - .clone() - .or_else(|| cached.extensions.account_data.lists.clone()); - request.extensions.account_data.rooms = request - .extensions - .account_data - .rooms - .clone() - .or_else(|| cached.extensions.account_data.rooms.clone()); + request.extensions.account_data.enabled = + request.extensions.account_data.enabled.or(cached.extensions.account_data.enabled); + request.extensions.account_data.lists = + request.extensions.account_data.lists.clone().or_else(|| cached.extensions.account_data.lists.clone()); + request.extensions.account_data.rooms = + request.extensions.account_data.rooms.clone().or_else(|| cached.extensions.account_data.rooms.clone()); - cached.extensions = request.extensions.clone(); + cached.extensions = request.extensions.clone(); - 
cached.known_rooms.clone() - } + cached.known_rooms.clone() + } - pub fn update_sync_subscriptions( - &self, - user_id: OwnedUserId, - device_id: OwnedDeviceId, - conn_id: String, - subscriptions: BTreeMap, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + pub fn update_sync_subscriptions( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, + subscriptions: BTreeMap, + ) { + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + })); + let cached = &mut cached.lock().unwrap(); + drop(cache); - cached.subscriptions = subscriptions; - } + cached.subscriptions = subscriptions; + } - pub fn update_sync_known_rooms( - &self, - user_id: OwnedUserId, - device_id: OwnedDeviceId, - conn_id: String, - list_id: String, - new_cached_rooms: BTreeSet, - globalsince: u64, - ) { - let mut cache = self.connections.lock().unwrap(); - let cached = Arc::clone( - cache - .entry((user_id, device_id, conn_id)) - .or_insert_with(|| { - Arc::new(Mutex::new(SlidingSyncCache { - lists: BTreeMap::new(), - subscriptions: BTreeMap::new(), - known_rooms: BTreeMap::new(), - extensions: ExtensionsConfig::default(), - })) - }), - ); - let cached = &mut cached.lock().unwrap(); - drop(cache); + pub fn update_sync_known_rooms( + &self, user_id: OwnedUserId, device_id: OwnedDeviceId, conn_id: String, list_id: String, + new_cached_rooms: BTreeSet, globalsince: u64, 
+ ) { + let mut cache = self.connections.lock().unwrap(); + let cached = Arc::clone(cache.entry((user_id, device_id, conn_id)).or_insert_with(|| { + Arc::new(Mutex::new(SlidingSyncCache { + lists: BTreeMap::new(), + subscriptions: BTreeMap::new(), + known_rooms: BTreeMap::new(), + extensions: ExtensionsConfig::default(), + })) + })); + let cached = &mut cached.lock().unwrap(); + drop(cache); - for (roomid, lastsince) in cached - .known_rooms - .entry(list_id.clone()) - .or_default() - .iter_mut() - { - if !new_cached_rooms.contains(roomid) { - *lastsince = 0; - } - } - let list = cached.known_rooms.entry(list_id).or_default(); - for roomid in new_cached_rooms { - list.insert(roomid, globalsince); - } - } + for (roomid, lastsince) in cached.known_rooms.entry(list_id.clone()).or_default().iter_mut() { + if !new_cached_rooms.contains(roomid) { + *lastsince = 0; + } + } + let list = cached.known_rooms.entry(list_id).or_default(); + for roomid in new_cached_rooms { + list.insert(roomid, globalsince); + } + } - /// Check if account is deactivated - pub fn is_deactivated(&self, user_id: &UserId) -> Result { - self.db.is_deactivated(user_id) - } + /// Check if account is deactivated + pub fn is_deactivated(&self, user_id: &UserId) -> Result { self.db.is_deactivated(user_id) } - /// Check if a user is an admin - pub fn is_admin(&self, user_id: &UserId) -> Result { - let admin_room_alias_id = - RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; - let admin_room_id = services() - .rooms - .alias - .resolve_local_alias(&admin_room_alias_id)? 
- .unwrap(); + /// Check if a user is an admin + pub fn is_admin(&self, user_id: &UserId) -> Result { + let admin_room_alias_id = RoomAliasId::parse(format!("#admins:{}", services().globals.server_name())) + .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid alias."))?; + let admin_room_id = services().rooms.alias.resolve_local_alias(&admin_room_alias_id)?.unwrap(); - services() - .rooms - .state_cache - .is_joined(user_id, &admin_room_id) - } + services().rooms.state_cache.is_joined(user_id, &admin_room_id) + } - /// Create a new user account on this homeserver. - pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password)?; - Ok(()) - } + /// Create a new user account on this homeserver. + pub fn create(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.db.set_password(user_id, password)?; + Ok(()) + } - /// Returns the number of users registered on this server. - pub fn count(&self) -> Result { - self.db.count() - } + /// Returns the number of users registered on this server. + pub fn count(&self) -> Result { self.db.count() } - /// Find out which user an access token belongs to. - pub fn find_from_token(&self, token: &str) -> Result> { - self.db.find_from_token(token) - } + /// Find out which user an access token belongs to. + pub fn find_from_token(&self, token: &str) -> Result> { + self.db.find_from_token(token) + } - /// Returns an iterator over all users on this homeserver. - pub fn iter(&self) -> impl Iterator> + '_ { - self.db.iter() - } + /// Returns an iterator over all users on this homeserver. + pub fn iter(&self) -> impl Iterator> + '_ { self.db.iter() } - /// Returns a list of local users as list of usernames. - /// - /// A user account is considered `local` if the length of it's password is greater then zero. - pub fn list_local_users(&self) -> Result> { - self.db.list_local_users() - } + /// Returns a list of local users as list of usernames. 
+ /// + /// A user account is considered `local` if the length of its password is + /// greater than zero. + pub fn list_local_users(&self) -> Result> { self.db.list_local_users() } - /// Returns the password hash for the given user. - pub fn password_hash(&self, user_id: &UserId) -> Result> { - self.db.password_hash(user_id) - } + /// Returns the password hash for the given user. + pub fn password_hash(&self, user_id: &UserId) -> Result> { self.db.password_hash(user_id) } - /// Hash and set the user's password to the Argon2 hash - pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { - self.db.set_password(user_id, password) - } + /// Hash and set the user's password to the Argon2 hash + pub fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> { + self.db.set_password(user_id, password) + } - /// Returns the displayname of a user on this homeserver. - pub fn displayname(&self, user_id: &UserId) -> Result> { - self.db.displayname(user_id) - } + /// Returns the displayname of a user on this homeserver. + pub fn displayname(&self, user_id: &UserId) -> Result> { self.db.displayname(user_id) } - /// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change. - pub async fn set_displayname( - &self, - user_id: &UserId, - displayname: Option, - ) -> Result<()> { - self.db.set_displayname(user_id, displayname) - } + /// Sets a new displayname or removes it if displayname is None. You still + /// need to notify all rooms of this change. + pub async fn set_displayname(&self, user_id: &UserId, displayname: Option) -> Result<()> { + self.db.set_displayname(user_id, displayname) + } - /// Get the avatar_url of a user. - pub fn avatar_url(&self, user_id: &UserId) -> Result> { - self.db.avatar_url(user_id) - } + /// Get the avatar_url of a user. 
+ pub fn avatar_url(&self, user_id: &UserId) -> Result> { self.db.avatar_url(user_id) } - /// Sets a new avatar_url or removes it if avatar_url is None. - pub async fn set_avatar_url( - &self, - user_id: &UserId, - avatar_url: Option, - ) -> Result<()> { - self.db.set_avatar_url(user_id, avatar_url) - } + /// Sets a new avatar_url or removes it if avatar_url is None. + pub async fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option) -> Result<()> { + self.db.set_avatar_url(user_id, avatar_url) + } - /// Get the blurhash of a user. - pub fn blurhash(&self, user_id: &UserId) -> Result> { - self.db.blurhash(user_id) - } + /// Get the blurhash of a user. + pub fn blurhash(&self, user_id: &UserId) -> Result> { self.db.blurhash(user_id) } - /// Sets a new avatar_url or removes it if avatar_url is None. - pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { - self.db.set_blurhash(user_id, blurhash) - } + /// Sets a new blurhash or removes it if blurhash is None. + pub async fn set_blurhash(&self, user_id: &UserId, blurhash: Option) -> Result<()> { + self.db.set_blurhash(user_id, blurhash) + } - /// Adds a new device to a user. - pub fn create_device( - &self, - user_id: &UserId, - device_id: &DeviceId, - token: &str, - initial_device_display_name: Option, - ) -> Result<()> { - self.db - .create_device(user_id, device_id, token, initial_device_display_name) - } + /// Adds a new device to a user. + pub fn create_device( + &self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option, + ) -> Result<()> { + self.db.create_device(user_id, device_id, token, initial_device_display_name) + } - /// Removes a device from a user. 
+ pub fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> { + self.db.remove_device(user_id, device_id) + } - /// Returns an iterator over all device ids of this user. - pub fn all_device_ids<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator> + 'a { - self.db.all_device_ids(user_id) - } + /// Returns an iterator over all device ids of this user. + pub fn all_device_ids<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + self.db.all_device_ids(user_id) + } - /// Replaces the access token of one device. - pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { - self.db.set_token(user_id, device_id, token) - } + /// Replaces the access token of one device. + pub fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> { + self.db.set_token(user_id, device_id, token) + } - pub fn add_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - one_time_key_key: &DeviceKeyId, - one_time_key_value: &Raw, - ) -> Result<()> { - self.db - .add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) - } + pub fn add_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId, + one_time_key_value: &Raw, + ) -> Result<()> { + self.db.add_one_time_key(user_id, device_id, one_time_key_key, one_time_key_value) + } - pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { - self.db.last_one_time_keys_update(user_id) - } + pub fn last_one_time_keys_update(&self, user_id: &UserId) -> Result { + self.db.last_one_time_keys_update(user_id) + } - pub fn take_one_time_key( - &self, - user_id: &UserId, - device_id: &DeviceId, - key_algorithm: &DeviceKeyAlgorithm, - ) -> Result)>> { - self.db.take_one_time_key(user_id, device_id, key_algorithm) - } + pub fn take_one_time_key( + &self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm, + ) -> Result)>> { + 
self.db.take_one_time_key(user_id, device_id, key_algorithm) + } - pub fn count_one_time_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - self.db.count_one_time_keys(user_id, device_id) - } + pub fn count_one_time_keys( + &self, user_id: &UserId, device_id: &DeviceId, + ) -> Result> { + self.db.count_one_time_keys(user_id, device_id) + } - pub fn add_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - device_keys: &Raw, - ) -> Result<()> { - self.db.add_device_keys(user_id, device_id, device_keys) - } + pub fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw) -> Result<()> { + self.db.add_device_keys(user_id, device_id, device_keys) + } - pub fn add_cross_signing_keys( - &self, - user_id: &UserId, - master_key: &Raw, - self_signing_key: &Option>, - user_signing_key: &Option>, - notify: bool, - ) -> Result<()> { - self.db.add_cross_signing_keys( - user_id, - master_key, - self_signing_key, - user_signing_key, - notify, - ) - } + pub fn add_cross_signing_keys( + &self, user_id: &UserId, master_key: &Raw, self_signing_key: &Option>, + user_signing_key: &Option>, notify: bool, + ) -> Result<()> { + self.db.add_cross_signing_keys(user_id, master_key, self_signing_key, user_signing_key, notify) + } - pub fn sign_key( - &self, - target_id: &UserId, - key_id: &str, - signature: (String, String), - sender_id: &UserId, - ) -> Result<()> { - self.db.sign_key(target_id, key_id, signature, sender_id) - } + pub fn sign_key( + &self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId, + ) -> Result<()> { + self.db.sign_key(target_id, key_id, signature, sender_id) + } - pub fn keys_changed<'a>( - &'a self, - user_or_room_id: &str, - from: u64, - to: Option, - ) -> impl Iterator> + 'a { - self.db.keys_changed(user_or_room_id, from, to) - } + pub fn keys_changed<'a>( + &'a self, user_or_room_id: &str, from: u64, to: Option, + ) -> impl Iterator> + 'a { + 
self.db.keys_changed(user_or_room_id, from, to) + } - pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { - self.db.mark_device_key_update(user_id) - } + pub fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> { self.db.mark_device_key_update(user_id) } - pub fn get_device_keys( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>> { - self.db.get_device_keys(user_id, device_id) - } + pub fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + self.db.get_device_keys(user_id, device_id) + } - pub fn parse_master_key( - &self, - user_id: &UserId, - master_key: &Raw, - ) -> Result<(Vec, CrossSigningKey)> { - self.db.parse_master_key(user_id, master_key) - } + pub fn parse_master_key( + &self, user_id: &UserId, master_key: &Raw, + ) -> Result<(Vec, CrossSigningKey)> { + self.db.parse_master_key(user_id, master_key) + } - pub fn get_key( - &self, - key: &[u8], - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_key(key, sender_user, user_id, allowed_signatures) - } + pub fn get_key( + &self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.db.get_key(key, sender_user, user_id, allowed_signatures) + } - pub fn get_master_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_master_key(sender_user, user_id, allowed_signatures) - } + pub fn get_master_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.db.get_master_key(sender_user, user_id, allowed_signatures) + } - pub fn get_self_signing_key( - &self, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: &dyn Fn(&UserId) -> bool, - ) -> Result>> { - self.db - .get_self_signing_key(sender_user, 
user_id, allowed_signatures) - } + pub fn get_self_signing_key( + &self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool, + ) -> Result>> { + self.db.get_self_signing_key(sender_user, user_id, allowed_signatures) + } - pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { - self.db.get_user_signing_key(user_id) - } + pub fn get_user_signing_key(&self, user_id: &UserId) -> Result>> { + self.db.get_user_signing_key(user_id) + } - pub fn add_to_device_event( - &self, - sender: &UserId, - target_user_id: &UserId, - target_device_id: &DeviceId, - event_type: &str, - content: serde_json::Value, - ) -> Result<()> { - self.db.add_to_device_event( - sender, - target_user_id, - target_device_id, - event_type, - content, - ) - } + pub fn add_to_device_event( + &self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str, + content: serde_json::Value, + ) -> Result<()> { + self.db.add_to_device_event(sender, target_user_id, target_device_id, event_type, content) + } - pub fn get_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result>> { - self.db.get_to_device_events(user_id, device_id) - } + pub fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result>> { + self.db.get_to_device_events(user_id, device_id) + } - pub fn remove_to_device_events( - &self, - user_id: &UserId, - device_id: &DeviceId, - until: u64, - ) -> Result<()> { - self.db.remove_to_device_events(user_id, device_id, until) - } + pub fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> { + self.db.remove_to_device_events(user_id, device_id, until) + } - pub fn update_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - device: &Device, - ) -> Result<()> { - self.db.update_device_metadata(user_id, device_id, device) - } + pub fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: 
&Device) -> Result<()> { + self.db.update_device_metadata(user_id, device_id, device) + } - /// Get device metadata. - pub fn get_device_metadata( - &self, - user_id: &UserId, - device_id: &DeviceId, - ) -> Result> { - self.db.get_device_metadata(user_id, device_id) - } + /// Get device metadata. + pub fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result> { + self.db.get_device_metadata(user_id, device_id) + } - pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { - self.db.get_devicelist_version(user_id) - } + pub fn get_devicelist_version(&self, user_id: &UserId) -> Result> { + self.db.get_devicelist_version(user_id) + } - pub fn all_devices_metadata<'a>( - &'a self, - user_id: &UserId, - ) -> impl Iterator> + 'a { - self.db.all_devices_metadata(user_id) - } + pub fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> impl Iterator> + 'a { + self.db.all_devices_metadata(user_id) + } - /// Deactivate account - pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { - // Remove all associated devices - for device_id in self.all_device_ids(user_id) { - self.remove_device(user_id, &device_id?)?; - } + /// Deactivate account + pub fn deactivate_account(&self, user_id: &UserId) -> Result<()> { + // Remove all associated devices + for device_id in self.all_device_ids(user_id) { + self.remove_device(user_id, &device_id?)?; + } - // Set the password to "" to indicate a deactivated account. Hashes will never result in an - // empty string, so the user will not be able to log in again. Systems like changing the - // password without logging in should check if the account is deactivated. - self.db.set_password(user_id, None)?; + // Set the password to "" to indicate a deactivated account. Hashes will never + // result in an empty string, so the user will not be able to log in again. + // Systems like changing the password without logging in should check if the + // account is deactivated. 
+ self.db.set_password(user_id, None)?; - // TODO: Unhook 3PID - Ok(()) - } + // TODO: Unhook 3PID + Ok(()) + } - /// Creates a new sync filter. Returns the filter id. - pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { - self.db.create_filter(user_id, filter) - } + /// Creates a new sync filter. Returns the filter id. + pub fn create_filter(&self, user_id: &UserId, filter: &FilterDefinition) -> Result { + self.db.create_filter(user_id, filter) + } - pub fn get_filter( - &self, - user_id: &UserId, - filter_id: &str, - ) -> Result> { - self.db.get_filter(user_id, filter_id) - } + pub fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result> { + self.db.get_filter(user_id, filter_id) + } } /// Ensure that a user only sees signatures from themselves and the target user pub fn clean_signatures bool>( - cross_signing_key: &mut serde_json::Value, - sender_user: Option<&UserId>, - user_id: &UserId, - allowed_signatures: F, + cross_signing_key: &mut serde_json::Value, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: F, ) -> Result<(), Error> { - if let Some(signatures) = cross_signing_key - .get_mut("signatures") - .and_then(|v| v.as_object_mut()) - { - // Don't allocate for the full size of the current signatures, but require - // at most one resize if nothing is dropped - let new_capacity = signatures.len() / 2; - for (user, signature) in - mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) - { - let sid = <&UserId>::try_from(user.as_str()) - .map_err(|_| Error::bad_database("Invalid user ID in database."))?; - if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { - signatures.insert(user, signature); - } - } - } + if let Some(signatures) = cross_signing_key.get_mut("signatures").and_then(|v| v.as_object_mut()) { + // Don't allocate for the full size of the current signatures, but require + // at most one resize if nothing is dropped + let new_capacity = 
signatures.len() / 2; + for (user, signature) in mem::replace(signatures, serde_json::Map::with_capacity(new_capacity)) { + let sid = + <&UserId>::try_from(user.as_str()).map_err(|_| Error::bad_database("Invalid user ID in database."))?; + if sender_user == Some(user_id) || sid == user_id || allowed_signatures(sid) { + signatures.insert(user, signature); + } + } + } - Ok(()) + Ok(()) } diff --git a/src/utils/error.rs b/src/utils/error.rs index ba6caf5e..60209860 100644 --- a/src/utils/error.rs +++ b/src/utils/error.rs @@ -2,11 +2,11 @@ use std::convert::Infallible; use http::StatusCode; use ruma::{ - api::client::{ - error::{Error as RumaError, ErrorBody, ErrorKind}, - uiaa::{UiaaInfo, UiaaResponse}, - }, - OwnedServerName, + api::client::{ + error::{Error as RumaError, ErrorBody, ErrorKind}, + uiaa::{UiaaInfo, UiaaResponse}, + }, + OwnedServerName, }; use thiserror::Error; use tracing::{error, info}; @@ -17,149 +17,165 @@ pub type Result = std::result::Result; #[derive(Error, Debug)] pub enum Error { - #[cfg(feature = "sqlite")] - #[error("There was a problem with the connection to the sqlite database: {source}")] - SqliteError { - #[from] - source: rusqlite::Error, - }, - #[cfg(feature = "rocksdb")] - #[error("There was a problem with the connection to the rocksdb database: {source}")] - RocksDbError { - #[from] - source: rocksdb::Error, - }, - #[error("Could not generate an image.")] - ImageError { - #[from] - source: image::error::ImageError, - }, - #[error("Could not connect to server: {source}")] - ReqwestError { - #[from] - source: reqwest::Error, - }, - #[error("{0}")] - FederationError(OwnedServerName, RumaError), - #[error("Could not do this io: {source}")] - IoError { - #[from] - source: std::io::Error, - }, - #[error("{0}")] - BadServerResponse(&'static str), - #[error("{0}")] - BadConfig(&'static str), - #[error("{0}")] - /// Don't create this directly. Use Error::bad_database instead. 
- BadDatabase(&'static str), - #[error("uiaa")] - Uiaa(UiaaInfo), - #[error("{0}: {1}")] - BadRequest(ErrorKind, &'static str), - #[error("{0}")] - Conflict(&'static str), // This is only needed for when a room alias already exists - #[cfg(feature = "conduit_bin")] - #[error("{0}")] - ExtensionError(#[from] axum::extract::rejection::ExtensionRejection), - #[cfg(feature = "conduit_bin")] - #[error("{0}")] - PathError(#[from] axum::extract::rejection::PathRejection), - #[error("from {0}: {1}")] - RedactionError(OwnedServerName, ruma::canonical_json::RedactionError), - #[error("{0} in {1}")] - InconsistentRoomState(&'static str, ruma::OwnedRoomId), + #[cfg(feature = "sqlite")] + #[error("There was a problem with the connection to the sqlite database: {source}")] + SqliteError { + #[from] + source: rusqlite::Error, + }, + #[cfg(feature = "rocksdb")] + #[error("There was a problem with the connection to the rocksdb database: {source}")] + RocksDbError { + #[from] + source: rocksdb::Error, + }, + #[error("Could not generate an image.")] + ImageError { + #[from] + source: image::error::ImageError, + }, + #[error("Could not connect to server: {source}")] + ReqwestError { + #[from] + source: reqwest::Error, + }, + #[error("{0}")] + FederationError(OwnedServerName, RumaError), + #[error("Could not do this io: {source}")] + IoError { + #[from] + source: std::io::Error, + }, + #[error("{0}")] + BadServerResponse(&'static str), + #[error("{0}")] + BadConfig(&'static str), + #[error("{0}")] + /// Don't create this directly. Use Error::bad_database instead. 
+ BadDatabase(&'static str), + #[error("uiaa")] + Uiaa(UiaaInfo), + #[error("{0}: {1}")] + BadRequest(ErrorKind, &'static str), + #[error("{0}")] + Conflict(&'static str), // This is only needed for when a room alias already exists + #[cfg(feature = "conduit_bin")] + #[error("{0}")] + ExtensionError(#[from] axum::extract::rejection::ExtensionRejection), + #[cfg(feature = "conduit_bin")] + #[error("{0}")] + PathError(#[from] axum::extract::rejection::PathRejection), + #[error("from {0}: {1}")] + RedactionError(OwnedServerName, ruma::canonical_json::RedactionError), + #[error("{0} in {1}")] + InconsistentRoomState(&'static str, ruma::OwnedRoomId), } impl Error { - pub fn bad_database(message: &'static str) -> Self { - error!("BadDatabase: {}", message); - Self::BadDatabase(message) - } + pub fn bad_database(message: &'static str) -> Self { + error!("BadDatabase: {}", message); + Self::BadDatabase(message) + } - pub fn bad_config(message: &'static str) -> Self { - error!("BadConfig: {}", message); - Self::BadConfig(message) - } + pub fn bad_config(message: &'static str) -> Self { + error!("BadConfig: {}", message); + Self::BadConfig(message) + } } impl Error { - pub fn to_response(&self) -> RumaResponse { - if let Self::Uiaa(uiaainfo) = self { - return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); - } + pub fn to_response(&self) -> RumaResponse { + if let Self::Uiaa(uiaainfo) = self { + return RumaResponse(UiaaResponse::AuthResponse(uiaainfo.clone())); + } - if let Self::FederationError(origin, error) = self { - let mut error = error.clone(); - error.body = ErrorBody::Standard { - kind: Unknown, - message: format!("Answer from {origin}: {error}"), - }; - return RumaResponse(UiaaResponse::MatrixError(error)); - } + if let Self::FederationError(origin, error) = self { + let mut error = error.clone(); + error.body = ErrorBody::Standard { + kind: Unknown, + message: format!("Answer from {origin}: {error}"), + }; + return 
RumaResponse(UiaaResponse::MatrixError(error)); + } - let message = format!("{self}"); + let message = format!("{self}"); - use ErrorKind::{ - Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, - ThreepidAuthFailed, ThreepidDenied, TooLarge, Unauthorized, Unknown, UnknownToken, - Unrecognized, UserDeactivated, WrongRoomKeysVersion, - }; - let (kind, status_code) = match self { - Self::BadRequest(kind, _) => ( - kind.clone(), - match kind { - WrongRoomKeysVersion { .. } - | Forbidden - | GuestAccessForbidden - | ThreepidAuthFailed - | ThreepidDenied => StatusCode::FORBIDDEN, - Unauthorized | UnknownToken { .. } | MissingToken => StatusCode::UNAUTHORIZED, - NotFound | Unrecognized => StatusCode::NOT_FOUND, - LimitExceeded { .. } => StatusCode::TOO_MANY_REQUESTS, - UserDeactivated => StatusCode::FORBIDDEN, - TooLarge => StatusCode::PAYLOAD_TOO_LARGE, - _ => StatusCode::BAD_REQUEST, - }, - ), - Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), - _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), - }; + use ErrorKind::{ + Forbidden, GuestAccessForbidden, LimitExceeded, MissingToken, NotFound, ThreepidAuthFailed, ThreepidDenied, + TooLarge, Unauthorized, Unknown, UnknownToken, Unrecognized, UserDeactivated, WrongRoomKeysVersion, + }; + let (kind, status_code) = match self { + Self::BadRequest(kind, _) => ( + kind.clone(), + match kind { + WrongRoomKeysVersion { + .. + } + | Forbidden + | GuestAccessForbidden + | ThreepidAuthFailed + | ThreepidDenied => StatusCode::FORBIDDEN, + Unauthorized + | UnknownToken { + .. + } + | MissingToken => StatusCode::UNAUTHORIZED, + NotFound | Unrecognized => StatusCode::NOT_FOUND, + LimitExceeded { + .. 
+ } => StatusCode::TOO_MANY_REQUESTS, + UserDeactivated => StatusCode::FORBIDDEN, + TooLarge => StatusCode::PAYLOAD_TOO_LARGE, + _ => StatusCode::BAD_REQUEST, + }, + ), + Self::Conflict(_) => (Unknown, StatusCode::CONFLICT), + _ => (Unknown, StatusCode::INTERNAL_SERVER_ERROR), + }; - info!("Returning an error: {}: {}", status_code, message); + info!("Returning an error: {}: {}", status_code, message); - RumaResponse(UiaaResponse::MatrixError(RumaError { - body: ErrorBody::Standard { kind, message }, - status_code, - })) - } + RumaResponse(UiaaResponse::MatrixError(RumaError { + body: ErrorBody::Standard { + kind, + message, + }, + status_code, + })) + } - /// Sanitizes public-facing errors that can leak sensitive information. - pub fn sanitized_error(&self) -> String { - let db_error = String::from("Database or I/O error occurred."); + /// Sanitizes public-facing errors that can leak sensitive information. + pub fn sanitized_error(&self) -> String { + let db_error = String::from("Database or I/O error occurred."); - match self { - #[cfg(feature = "sqlite")] - Self::SqliteError { .. } => db_error, - #[cfg(feature = "rocksdb")] - Self::RocksDbError { .. } => db_error, - Self::IoError { .. } => db_error, - Self::BadConfig { .. } => db_error, - Self::BadDatabase { .. } => db_error, - _ => self.to_string(), - } - } + match self { + #[cfg(feature = "sqlite")] + Self::SqliteError { + .. + } => db_error, + #[cfg(feature = "rocksdb")] + Self::RocksDbError { + .. + } => db_error, + Self::IoError { + .. + } => db_error, + Self::BadConfig { + .. + } => db_error, + Self::BadDatabase { + .. 
+ } => db_error, + _ => self.to_string(), + } + } } impl From for Error { - fn from(i: Infallible) -> Self { - match i {} - } + fn from(i: Infallible) -> Self { match i {} } } #[cfg(feature = "conduit_bin")] impl axum::response::IntoResponse for Error { - fn into_response(self) -> axum::response::Response { - self.to_response().into_response() - } + fn into_response(self) -> axum::response::Response { self.to_response().into_response() } } diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 37ebc4f8..b6edede1 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,158 +1,136 @@ pub(crate) mod error; -use crate::{services, Error, Result}; +use std::{ + cmp::Ordering, + fmt, + str::FromStr, + time::{SystemTime, UNIX_EPOCH}, +}; + use argon2::{password_hash::SaltString, PasswordHasher}; use rand::prelude::*; use ring::digest; -use ruma::{ - canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId, -}; -use std::{ - cmp::Ordering, - fmt, - str::FromStr, - time::{SystemTime, UNIX_EPOCH}, -}; +use ruma::{canonical_json::try_from_json_map, CanonicalJsonError, CanonicalJsonObject, OwnedUserId}; + +use crate::{services, Error, Result}; pub(crate) fn millis_since_unix_epoch() -> u64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("time is valid") - .as_millis() as u64 + SystemTime::now().duration_since(UNIX_EPOCH).expect("time is valid").as_millis() as u64 } pub(crate) fn increment(old: Option<&[u8]>) -> Option> { - let number = match old.map(std::convert::TryInto::try_into) { - Some(Ok(bytes)) => { - let number = u64::from_be_bytes(bytes); - number + 1 - } - _ => 1, // Start at one. since 0 should return the first event in the db - }; + let number = match old.map(std::convert::TryInto::try_into) { + Some(Ok(bytes)) => { + let number = u64::from_be_bytes(bytes); + number + 1 + }, + _ => 1, // Start at one. 
since 0 should return the first event in the db + }; - Some(number.to_be_bytes().to_vec()) + Some(number.to_be_bytes().to_vec()) } pub fn generate_keypair() -> Vec { - let mut value = random_string(8).as_bytes().to_vec(); - value.push(0xff); - value.extend_from_slice( - &ruma::signatures::Ed25519KeyPair::generate() - .expect("Ed25519KeyPair generation always works (?)"), - ); - value + let mut value = random_string(8).as_bytes().to_vec(); + value.push(0xFF); + value.extend_from_slice( + &ruma::signatures::Ed25519KeyPair::generate().expect("Ed25519KeyPair generation always works (?)"), + ); + value } /// Parses the bytes into an u64. pub fn u64_from_bytes(bytes: &[u8]) -> Result { - let array: [u8; 8] = bytes.try_into()?; - Ok(u64::from_be_bytes(array)) + let array: [u8; 8] = bytes.try_into()?; + Ok(u64::from_be_bytes(array)) } /// Parses the bytes into a string. pub fn string_from_bytes(bytes: &[u8]) -> Result { - String::from_utf8(bytes.to_vec()) + String::from_utf8(bytes.to_vec()) } /// Parses a OwnedUserId from bytes. 
pub fn user_id_from_bytes(bytes: &[u8]) -> Result { - OwnedUserId::try_from( - string_from_bytes(bytes) - .map_err(|_| Error::bad_database("Failed to parse string from bytes"))?, - ) - .map_err(|_| Error::bad_database("Failed to parse user id from bytes")) + OwnedUserId::try_from( + string_from_bytes(bytes).map_err(|_| Error::bad_database("Failed to parse string from bytes"))?, + ) + .map_err(|_| Error::bad_database("Failed to parse user id from bytes")) } pub fn random_string(length: usize) -> String { - thread_rng() - .sample_iter(&rand::distributions::Alphanumeric) - .take(length) - .map(char::from) - .collect() + thread_rng().sample_iter(&rand::distributions::Alphanumeric).take(length).map(char::from).collect() } /// Calculate a new hash for the given password pub fn calculate_password_hash(password: &str) -> Result { - let salt = SaltString::generate(thread_rng()); - services() - .globals - .argon - .hash_password(password.as_bytes(), &salt) - .map(|it| it.to_string()) + let salt = SaltString::generate(thread_rng()); + services().globals.argon.hash_password(password.as_bytes(), &salt).map(|it| it.to_string()) } #[tracing::instrument(skip(keys))] pub fn calculate_hash(keys: &[&[u8]]) -> Vec { - // We only hash the pdu's event ids, not the whole pdu - let bytes = keys.join(&0xff); - let hash = digest::digest(&digest::SHA256, &bytes); - hash.as_ref().to_owned() + // We only hash the pdu's event ids, not the whole pdu + let bytes = keys.join(&0xFF); + let hash = digest::digest(&digest::SHA256, &bytes); + hash.as_ref().to_owned() } pub(crate) fn common_elements( - mut iterators: impl Iterator>>, - check_order: impl Fn(&[u8], &[u8]) -> Ordering, + mut iterators: impl Iterator>>, check_order: impl Fn(&[u8], &[u8]) -> Ordering, ) -> Option>> { - let first_iterator = iterators.next()?; - let mut other_iterators = iterators - .map(std::iter::Iterator::peekable) - .collect::>(); + let first_iterator = iterators.next()?; + let mut other_iterators = 
iterators.map(std::iter::Iterator::peekable).collect::>(); - Some(first_iterator.filter(move |target| { - other_iterators.iter_mut().all(|it| { - while let Some(element) = it.peek() { - match check_order(element, target) { - Ordering::Greater => return false, // We went too far - Ordering::Equal => return true, // Element is in both iters - Ordering::Less => { - // Keep searching - it.next(); - } - } - } - false - }) - })) + Some(first_iterator.filter(move |target| { + other_iterators.iter_mut().all(|it| { + while let Some(element) = it.peek() { + match check_order(element, target) { + Ordering::Greater => return false, // We went too far + Ordering::Equal => return true, // Element is in both iters + Ordering::Less => { + // Keep searching + it.next(); + }, + } + } + false + }) + })) } -/// Fallible conversion from any value that implements `Serialize` to a `CanonicalJsonObject`. +/// Fallible conversion from any value that implements `Serialize` to a +/// `CanonicalJsonObject`. /// /// `value` must serialize to an `serde_json::Value::Object`. -pub(crate) fn to_canonical_object( - value: T, -) -> Result { - use serde::ser::Error; +pub(crate) fn to_canonical_object(value: T) -> Result { + use serde::ser::Error; - match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? { - serde_json::Value::Object(map) => try_from_json_map(map), - _ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom( - "Value must be an object", - ))), - } + match serde_json::to_value(value).map_err(CanonicalJsonError::SerDe)? 
{ + serde_json::Value::Object(map) => try_from_json_map(map), + _ => Err(CanonicalJsonError::SerDe(serde_json::Error::custom("Value must be an object"))), + } } -pub(crate) fn deserialize_from_str< - 'de, - D: serde::de::Deserializer<'de>, - T: FromStr, - E: std::fmt::Display, ->( - deserializer: D, +pub(crate) fn deserialize_from_str<'de, D: serde::de::Deserializer<'de>, T: FromStr, E: std::fmt::Display>( + deserializer: D, ) -> Result { - struct Visitor, E>(std::marker::PhantomData); - impl, Err: std::fmt::Display> serde::de::Visitor<'_> for Visitor { - type Value = T; - fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(formatter, "a parsable string") - } - fn visit_str(self, v: &str) -> Result - where - E: serde::de::Error, - { - v.parse().map_err(serde::de::Error::custom) - } - } - deserializer.deserialize_str(Visitor(std::marker::PhantomData)) + struct Visitor, E>(std::marker::PhantomData); + impl, Err: std::fmt::Display> serde::de::Visitor<'_> for Visitor { + type Value = T; + + fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + write!(formatter, "a parsable string") + } + + fn visit_str(self, v: &str) -> Result + where + E: serde::de::Error, + { + v.parse().map_err(serde::de::Error::custom) + } + } + deserializer.deserialize_str(Visitor(std::marker::PhantomData)) } // Copied from librustdoc: @@ -163,31 +141,31 @@ pub(crate) fn deserialize_from_str< pub(crate) struct HtmlEscape<'a>(pub(crate) &'a str); impl fmt::Display for HtmlEscape<'_> { - fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { - // Because the internet is always right, turns out there's not that many - // characters to escape: http://stackoverflow.com/questions/7381974 - let HtmlEscape(s) = *self; - let pile_o_bits = s; - let mut last = 0; - for (i, ch) in s.char_indices() { - let s = match ch { - '>' => ">", - '<' => "<", - '&' => "&", - '\'' => "'", - '"' => """, - _ => continue, - }; - 
fmt.write_str(&pile_o_bits[last..i])?; - fmt.write_str(s)?; - // NOTE: we only expect single byte characters here - which is fine as long as we - // only match single byte characters - last = i + 1; - } + fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result { + // Because the internet is always right, turns out there's not that many + // characters to escape: http://stackoverflow.com/questions/7381974 + let HtmlEscape(s) = *self; + let pile_o_bits = s; + let mut last = 0; + for (i, ch) in s.char_indices() { + let s = match ch { + '>' => ">", + '<' => "<", + '&' => "&", + '\'' => "'", + '"' => """, + _ => continue, + }; + fmt.write_str(&pile_o_bits[last..i])?; + fmt.write_str(s)?; + // NOTE: we only expect single byte characters here - which is fine as long as + // we only match single byte characters + last = i + 1; + } - if last < s.len() { - fmt.write_str(&pile_o_bits[last..])?; - } - Ok(()) - } + if last < s.len() { + fmt.write_str(&pile_o_bits[last..])?; + } + Ok(()) + } }