add rustfmt.toml, format entire codebase
Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
parent
9fd521f041
commit
f419c64aca
144 changed files with 25573 additions and 31053 deletions
27
rustfmt.toml
Normal file
27
rustfmt.toml
Normal file
|
@ -0,0 +1,27 @@
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
condense_wildcard_suffixes = true
|
||||||
|
format_code_in_doc_comments = true
|
||||||
|
format_macro_bodies = true
|
||||||
|
format_macro_matchers = true
|
||||||
|
format_strings = true
|
||||||
|
hex_literal_case = "Upper"
|
||||||
|
max_width = 120
|
||||||
|
tab_spaces = 4
|
||||||
|
array_width = 80
|
||||||
|
comment_width = 80
|
||||||
|
wrap_comments = true
|
||||||
|
fn_params_layout = "Compressed"
|
||||||
|
fn_call_width = 80
|
||||||
|
fn_single_line = true
|
||||||
|
hard_tabs = true
|
||||||
|
match_block_trailing_comma = true
|
||||||
|
imports_granularity = "Crate"
|
||||||
|
normalize_comments = false
|
||||||
|
reorder_impl_items = true
|
||||||
|
reorder_imports = true
|
||||||
|
group_imports = "StdExternalCrate"
|
||||||
|
newline_style = "Unix"
|
||||||
|
use_field_init_shorthand = true
|
||||||
|
use_small_heuristics = "Off"
|
||||||
|
use_try_shorthand = true
|
|
@ -1,18 +1,16 @@
|
||||||
use crate::{services, utils, Error, Result};
|
|
||||||
use bytes::BytesMut;
|
|
||||||
use ruma::api::{
|
|
||||||
appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
|
||||||
};
|
|
||||||
use std::{fmt::Debug, mem, time::Duration};
|
use std::{fmt::Debug, mem, time::Duration};
|
||||||
|
|
||||||
|
use bytes::BytesMut;
|
||||||
|
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::{services, utils, Error, Result};
|
||||||
|
|
||||||
/// Sends a request to an appservice
|
/// Sends a request to an appservice
|
||||||
///
|
///
|
||||||
/// Only returns None if there is no url specified in the appservice registration file
|
/// Only returns None if there is no url specified in the appservice
|
||||||
pub(crate) async fn send_request<T>(
|
/// registration file
|
||||||
registration: Registration,
|
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Option<Result<T::IncomingResponse>>
|
||||||
request: T,
|
|
||||||
) -> Option<Result<T::IncomingResponse>>
|
|
||||||
where
|
where
|
||||||
T: OutgoingRequest + Debug,
|
T: OutgoingRequest + Debug,
|
||||||
{
|
{
|
||||||
|
@ -40,25 +38,16 @@ where
|
||||||
"?"
|
"?"
|
||||||
};
|
};
|
||||||
|
|
||||||
parts.path_and_query = Some(
|
parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
|
||||||
(old_path_and_query + symbol + "access_token=" + hs_token)
|
|
||||||
.parse()
|
|
||||||
.unwrap(),
|
|
||||||
);
|
|
||||||
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
*http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");
|
||||||
|
|
||||||
let mut reqwest_request = reqwest::Request::try_from(http_request)
|
let mut reqwest_request =
|
||||||
.expect("all http requests are valid reqwest requests");
|
reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");
|
||||||
|
|
||||||
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
*reqwest_request.timeout_mut() = Some(Duration::from_secs(120));
|
||||||
|
|
||||||
let url = reqwest_request.url().clone();
|
let url = reqwest_request.url().clone();
|
||||||
let mut response = match services()
|
let mut response = match services().globals.default_client().execute(reqwest_request).await {
|
||||||
.globals
|
|
||||||
.default_client()
|
|
||||||
.execute(reqwest_request)
|
|
||||||
.await
|
|
||||||
{
|
|
||||||
Ok(r) => r,
|
Ok(r) => r,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
warn!(
|
||||||
|
@ -66,19 +55,15 @@ where
|
||||||
registration.id, destination, e
|
registration.id, destination, e
|
||||||
);
|
);
|
||||||
return Some(Err(e.into()));
|
return Some(Err(e.into()));
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// reqwest::Response -> http::Response conversion
|
// reqwest::Response -> http::Response conversion
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let mut http_response_builder = http::Response::builder()
|
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||||
.status(status)
|
|
||||||
.version(response.version());
|
|
||||||
mem::swap(
|
mem::swap(
|
||||||
response.headers_mut(),
|
response.headers_mut(),
|
||||||
http_response_builder
|
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||||
.headers_mut()
|
|
||||||
.expect("http::response::Builder is usable"),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||||
|
@ -97,15 +82,10 @@ where
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = T::IncomingResponse::try_from_http_response(
|
let response = T::IncomingResponse::try_from_http_response(
|
||||||
http_response_builder
|
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||||
.body(body)
|
|
||||||
.expect("reqwest body is valid http body"),
|
|
||||||
);
|
);
|
||||||
Some(response.map_err(|_| {
|
Some(response.map_err(|_| {
|
||||||
warn!(
|
warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
|
||||||
"Appservice returned invalid response bytes {}\n{}",
|
|
||||||
destination, url
|
|
||||||
);
|
|
||||||
Error::BadServerResponse("Server returned bad response.")
|
Error::BadServerResponse("Server returned bad response.")
|
||||||
}))
|
}))
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -1,11 +1,10 @@
|
||||||
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
|
use register::RegistrationKind;
|
||||||
use crate::{api::client_server, services, utils, Error, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
account::{
|
account::{
|
||||||
change_password, deactivate, get_3pids, get_username_availability, register,
|
change_password, deactivate, get_3pids, get_username_availability, register,
|
||||||
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn,
|
request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, whoami,
|
||||||
whoami, ThirdPartyIdRemovalStatus,
|
ThirdPartyIdRemovalStatus,
|
||||||
},
|
},
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||||
|
@ -15,7 +14,8 @@ use ruma::{
|
||||||
};
|
};
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
use register::RegistrationKind;
|
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
|
||||||
|
use crate::{api::client_server, services, utils, Error, Result, Ruma};
|
||||||
|
|
||||||
const RANDOM_USER_ID_LENGTH: usize = 10;
|
const RANDOM_USER_ID_LENGTH: usize = 10;
|
||||||
|
|
||||||
|
@ -28,130 +28,109 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
|
||||||
/// - The server name of the user id matches this server
|
/// - The server name of the user id matches this server
|
||||||
/// - No user or appservice on this server already claimed this username
|
/// - No user or appservice on this server already claimed this username
|
||||||
///
|
///
|
||||||
/// Note: This will not reserve the username, so the username might become invalid when trying to register
|
/// Note: This will not reserve the username, so the username might become
|
||||||
|
/// invalid when trying to register
|
||||||
pub async fn get_register_available_route(
|
pub async fn get_register_available_route(
|
||||||
body: Ruma<get_username_availability::v3::Request>,
|
body: Ruma<get_username_availability::v3::Request>,
|
||||||
) -> Result<get_username_availability::v3::Response> {
|
) -> Result<get_username_availability::v3::Response> {
|
||||||
// Validate user id
|
// Validate user id
|
||||||
let user_id = UserId::parse_with_server_name(
|
let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
|
||||||
body.username.to_lowercase(),
|
|
||||||
services().globals.server_name(),
|
|
||||||
)
|
|
||||||
.ok()
|
.ok()
|
||||||
.filter(|user_id| {
|
.filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name())
|
||||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||||
})
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::InvalidUsername,
|
|
||||||
"Username is invalid.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
// Check if username is creative enough
|
// Check if username is creative enough
|
||||||
if services().users.exists(&user_id)? {
|
if services().users.exists(&user_id)? {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||||
ErrorKind::UserInUse,
|
|
||||||
"Desired user ID is already taken.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().globals.forbidden_usernames().is_match(user_id.localpart()) {
|
||||||
.globals
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||||
.forbidden_usernames()
|
|
||||||
.is_match(user_id.localpart())
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Username is forbidden.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO add check for appservice namespaces
|
// TODO add check for appservice namespaces
|
||||||
|
|
||||||
// If no if check is true we have an username that's available to be used.
|
// If no if check is true we have an username that's available to be used.
|
||||||
Ok(get_username_availability::v3::Response { available: true })
|
Ok(get_username_availability::v3::Response {
|
||||||
|
available: true,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/register`
|
/// # `POST /_matrix/client/v3/register`
|
||||||
///
|
///
|
||||||
/// Register an account on this homeserver.
|
/// Register an account on this homeserver.
|
||||||
///
|
///
|
||||||
/// You can use [`GET /_matrix/client/v3/register/available`](fn.get_register_available_route.html)
|
/// You can use [`GET
|
||||||
/// to check if the user id is valid and available.
|
/// /_matrix/client/v3/register/available`](fn.get_register_available_route.
|
||||||
|
/// html) to check if the user id is valid and available.
|
||||||
///
|
///
|
||||||
/// - Only works if registration is enabled
|
/// - Only works if registration is enabled
|
||||||
/// - If type is guest: ignores all parameters except initial_device_display_name
|
/// - If type is guest: ignores all parameters except
|
||||||
|
/// initial_device_display_name
|
||||||
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
|
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
|
||||||
/// - If type is not guest and no username is given: Always fails after UIAA check
|
/// - If type is not guest and no username is given: Always fails after UIAA
|
||||||
|
/// check
|
||||||
/// - Creates a new account and populates it with default account data
|
/// - Creates a new account and populates it with default account data
|
||||||
/// - If `inhibit_login` is false: Creates a device and returns device id and access_token
|
/// - If `inhibit_login` is false: Creates a device and returns device id and
|
||||||
|
/// access_token
|
||||||
pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
|
pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
|
||||||
if !services().globals.allow_registration() && !body.from_appservice {
|
if !services().globals.allow_registration() && !body.from_appservice {
|
||||||
info!("Registration disabled and request not from known appservice, rejecting registration attempt for username {:?}", body.username);
|
info!(
|
||||||
return Err(Error::BadRequest(
|
"Registration disabled and request not from known appservice, rejecting registration attempt for username \
|
||||||
ErrorKind::Forbidden,
|
{:?}",
|
||||||
"Registration has been disabled.",
|
body.username
|
||||||
));
|
);
|
||||||
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration has been disabled."));
|
||||||
}
|
}
|
||||||
|
|
||||||
let is_guest = body.kind == RegistrationKind::Guest;
|
let is_guest = body.kind == RegistrationKind::Guest;
|
||||||
|
|
||||||
if is_guest
|
if is_guest
|
||||||
&& (!services().globals.allow_guest_registration()
|
&& (!services().globals.allow_guest_registration()
|
||||||
|| (services().globals.allow_registration()
|
|| (services().globals.allow_registration() && services().globals.config.registration_token.is_some()))
|
||||||
&& services().globals.config.registration_token.is_some()))
|
|
||||||
{
|
{
|
||||||
info!("Guest registration disabled / registration enabled with token configured, rejecting guest registration, initial device name: {:?}", body.initial_device_display_name);
|
info!(
|
||||||
|
"Guest registration disabled / registration enabled with token configured, rejecting guest registration, \
|
||||||
|
initial device name: {:?}",
|
||||||
|
body.initial_device_display_name
|
||||||
|
);
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::GuestAccessForbidden,
|
ErrorKind::GuestAccessForbidden,
|
||||||
"Guest registration is disabled.",
|
"Guest registration is disabled.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
// forbid guests from registering if there is not a real admin user yet. give generic user error.
|
// forbid guests from registering if there is not a real admin user yet. give
|
||||||
|
// generic user error.
|
||||||
if is_guest && services().users.count()? < 2 {
|
if is_guest && services().users.count()? < 2 {
|
||||||
warn!("Guest account attempted to register before a real admin user has been registered, rejecting registration. Guest's initial device name: {:?}", body.initial_device_display_name);
|
warn!(
|
||||||
return Err(Error::BadRequest(
|
"Guest account attempted to register before a real admin user has been registered, rejecting \
|
||||||
ErrorKind::Forbidden,
|
registration. Guest's initial device name: {:?}",
|
||||||
"Registration temporarily disabled.",
|
body.initial_device_display_name
|
||||||
));
|
);
|
||||||
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration temporarily disabled."));
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_id = match (&body.username, is_guest) {
|
let user_id = match (&body.username, is_guest) {
|
||||||
(Some(username), false) => {
|
(Some(username), false) => {
|
||||||
let proposed_user_id = UserId::parse_with_server_name(
|
let proposed_user_id =
|
||||||
username.to_lowercase(),
|
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
|
||||||
services().globals.server_name(),
|
|
||||||
)
|
|
||||||
.ok()
|
.ok()
|
||||||
.filter(|user_id| {
|
.filter(|user_id| {
|
||||||
!user_id.is_historical()
|
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||||
&& user_id.server_name() == services().globals.server_name()
|
|
||||||
})
|
})
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||||
ErrorKind::InvalidUsername,
|
|
||||||
"Username is invalid.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
if services().users.exists(&proposed_user_id)? {
|
if services().users.exists(&proposed_user_id)? {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||||
ErrorKind::UserInUse,
|
|
||||||
"Desired user ID is already taken.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) {
|
||||||
.globals
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||||
.forbidden_usernames()
|
|
||||||
.is_match(proposed_user_id.localpart())
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Username is forbidden.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
proposed_user_id
|
proposed_user_id
|
||||||
}
|
},
|
||||||
_ => loop {
|
_ => loop {
|
||||||
let proposed_user_id = UserId::parse_with_server_name(
|
let proposed_user_id = UserId::parse_with_server_name(
|
||||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||||
|
@ -196,8 +175,7 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||||
if !skip_auth {
|
if !skip_auth {
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) = services().uiaa.try_auth(
|
let (worked, uiaainfo) = services().uiaa.try_auth(
|
||||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||||
.expect("we know this is valid"),
|
|
||||||
"".into(),
|
"".into(),
|
||||||
auth,
|
auth,
|
||||||
&uiaainfo,
|
&uiaainfo,
|
||||||
|
@ -209,8 +187,7 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services().uiaa.create(
|
services().uiaa.create(
|
||||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||||
.expect("we know this is valid"),
|
|
||||||
"".into(),
|
"".into(),
|
||||||
&uiaainfo,
|
&uiaainfo,
|
||||||
&json,
|
&json,
|
||||||
|
@ -233,15 +210,13 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||||
// Default to pretty displayname
|
// Default to pretty displayname
|
||||||
let mut displayname = user_id.localpart().to_owned();
|
let mut displayname = user_id.localpart().to_owned();
|
||||||
|
|
||||||
// If `new_user_displayname_suffix` is set, registration will push whatever content is set to the user's display name with a space before it
|
// If `new_user_displayname_suffix` is set, registration will push whatever
|
||||||
|
// content is set to the user's display name with a space before it
|
||||||
if !services().globals.new_user_displayname_suffix().is_empty() {
|
if !services().globals.new_user_displayname_suffix().is_empty() {
|
||||||
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
|
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.set_displayname(&user_id, Some(displayname.clone())).await?;
|
||||||
.users
|
|
||||||
.set_displayname(&user_id, Some(displayname.clone()))
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Initial account data
|
// Initial account data
|
||||||
services().account_data.update(
|
services().account_data.update(
|
||||||
|
@ -279,41 +254,29 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||||
let token = utils::random_string(TOKEN_LENGTH);
|
let token = utils::random_string(TOKEN_LENGTH);
|
||||||
|
|
||||||
// Create device for this account
|
// Create device for this account
|
||||||
services().users.create_device(
|
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||||
&user_id,
|
|
||||||
&device_id,
|
|
||||||
&token,
|
|
||||||
body.initial_device_display_name.clone(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
info!("New user \"{}\" registered on this server.", user_id);
|
info!("New user \"{}\" registered on this server.", user_id);
|
||||||
|
|
||||||
// log in conduit admin channel if a non-guest user registered
|
// log in conduit admin channel if a non-guest user registered
|
||||||
if !body.from_appservice && !is_guest {
|
if !body.from_appservice && !is_guest {
|
||||||
services()
|
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||||
.admin
|
|
||||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
|
||||||
"New user \"{user_id}\" registered on this server."
|
"New user \"{user_id}\" registered on this server."
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// log in conduit admin channel if a guest registered
|
// log in conduit admin channel if a guest registered
|
||||||
if !body.from_appservice && is_guest {
|
if !body.from_appservice && is_guest {
|
||||||
services()
|
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||||
.admin
|
|
||||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
|
||||||
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
|
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
|
||||||
body.initial_device_display_name
|
body.initial_device_display_name
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
|
|
||||||
// If this is the first real user, grant them admin privileges except for guest users
|
// If this is the first real user, grant them admin privileges except for guest
|
||||||
// Note: the server user, @conduit:servername, is generated first
|
// users Note: the server user, @conduit:servername, is generated first
|
||||||
if services().users.count()? == 2 && !is_guest {
|
if services().users.count()? == 2 && !is_guest {
|
||||||
services()
|
services().admin.make_user_admin(&user_id, displayname).await?;
|
||||||
.admin
|
|
||||||
.make_user_admin(&user_id, displayname)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
warn!("Granting {} admin privileges as the first user", user_id);
|
warn!("Granting {} admin privileges as the first user", user_id);
|
||||||
}
|
}
|
||||||
|
@ -333,17 +296,18 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
||||||
///
|
///
|
||||||
/// - Requires UIAA to verify user password
|
/// - Requires UIAA to verify user password
|
||||||
/// - Changes the password of the sender user
|
/// - Changes the password of the sender user
|
||||||
/// - The password hash is calculated using argon2 with 32 character salt, the plain password is
|
/// - The password hash is calculated using argon2 with 32 character salt, the
|
||||||
|
/// plain password is
|
||||||
/// not saved
|
/// not saved
|
||||||
///
|
///
|
||||||
/// If logout_devices is true it does the following for each device except the sender device:
|
/// If logout_devices is true it does the following for each device except the
|
||||||
|
/// sender device:
|
||||||
/// - Invalidates access token
|
/// - Invalidates access token
|
||||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets to-device events
|
/// - Forgets to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
pub async fn change_password_route(
|
pub async fn change_password_route(body: Ruma<change_password::v3::Request>) -> Result<change_password::v3::Response> {
|
||||||
body: Ruma<change_password::v3::Request>,
|
|
||||||
) -> Result<change_password::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -358,27 +322,20 @@ pub async fn change_password_route(
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) =
|
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||||
services()
|
|
||||||
.uiaa
|
|
||||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
|
||||||
if !worked {
|
if !worked {
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
}
|
}
|
||||||
// Success!
|
// Success!
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services()
|
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||||
.uiaa
|
|
||||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.set_password(sender_user, Some(&body.new_password))?;
|
||||||
.users
|
|
||||||
.set_password(sender_user, Some(&body.new_password))?;
|
|
||||||
|
|
||||||
if body.logout_devices {
|
if body.logout_devices {
|
||||||
// Logout all devices except the current one
|
// Logout all devices except the current one
|
||||||
|
@ -393,9 +350,7 @@ pub async fn change_password_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
info!("User {} changed their password.", sender_user);
|
info!("User {} changed their password.", sender_user);
|
||||||
services()
|
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||||
.admin
|
|
||||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
|
||||||
"User {sender_user} changed their password."
|
"User {sender_user} changed their password."
|
||||||
)));
|
)));
|
||||||
|
|
||||||
|
@ -424,13 +379,12 @@ pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3:
|
||||||
///
|
///
|
||||||
/// - Leaves all rooms and rejects all invitations
|
/// - Leaves all rooms and rejects all invitations
|
||||||
/// - Invalidates all access tokens
|
/// - Invalidates all access tokens
|
||||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets all to-device events
|
/// - Forgets all to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
/// - Removes ability to log in again
|
/// - Removes ability to log in again
|
||||||
pub async fn deactivate_route(
|
pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<deactivate::v3::Response> {
|
||||||
body: Ruma<deactivate::v3::Request>,
|
|
||||||
) -> Result<deactivate::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -445,19 +399,14 @@ pub async fn deactivate_route(
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) =
|
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||||
services()
|
|
||||||
.uiaa
|
|
||||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
|
||||||
if !worked {
|
if !worked {
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
}
|
}
|
||||||
// Success!
|
// Success!
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services()
|
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||||
.uiaa
|
|
||||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||||
|
@ -470,9 +419,7 @@ pub async fn deactivate_route(
|
||||||
services().users.deactivate_account(sender_user)?;
|
services().users.deactivate_account(sender_user)?;
|
||||||
|
|
||||||
info!("User {} deactivated their account.", sender_user);
|
info!("User {} deactivated their account.", sender_user);
|
||||||
services()
|
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||||
.admin
|
|
||||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
|
||||||
"User {sender_user} deactivated their account."
|
"User {sender_user} deactivated their account."
|
||||||
)));
|
)));
|
||||||
|
|
||||||
|
@ -486,9 +433,7 @@ pub async fn deactivate_route(
|
||||||
/// Get a list of third party identifiers associated with this account.
|
/// Get a list of third party identifiers associated with this account.
|
||||||
///
|
///
|
||||||
/// - Currently always returns empty list
|
/// - Currently always returns empty list
|
||||||
pub async fn third_party_route(
|
pub async fn third_party_route(body: Ruma<get_3pids::v3::Request>) -> Result<get_3pids::v3::Response> {
|
||||||
body: Ruma<get_3pids::v3::Request>,
|
|
||||||
) -> Result<get_3pids::v3::Response> {
|
|
||||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
Ok(get_3pids::v3::Response::new(Vec::new()))
|
Ok(get_3pids::v3::Response::new(Vec::new()))
|
||||||
|
@ -496,9 +441,11 @@ pub async fn third_party_route(
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
|
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
|
||||||
///
|
///
|
||||||
/// "This API should be used to request validation tokens when adding an email address to an account"
|
/// "This API should be used to request validation tokens when adding an email
|
||||||
|
/// address to an account"
|
||||||
///
|
///
|
||||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||||
|
/// as a contact option.
|
||||||
pub async fn request_3pid_management_token_via_email_route(
|
pub async fn request_3pid_management_token_via_email_route(
|
||||||
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
|
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
|
||||||
) -> Result<request_3pid_management_token_via_email::v3::Response> {
|
) -> Result<request_3pid_management_token_via_email::v3::Response> {
|
||||||
|
@ -510,9 +457,11 @@ pub async fn request_3pid_management_token_via_email_route(
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
|
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
|
||||||
///
|
///
|
||||||
/// "This API should be used to request validation tokens when adding an phone number to an account"
|
/// "This API should be used to request validation tokens when adding an phone
|
||||||
|
/// number to an account"
|
||||||
///
|
///
|
||||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||||
|
/// as a contact option.
|
||||||
pub async fn request_3pid_management_token_via_msisdn_route(
|
pub async fn request_3pid_management_token_via_msisdn_route(
|
||||||
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
|
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
|
||||||
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {
|
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use rand::seq::SliceRandom;
|
use rand::seq::SliceRandom;
|
||||||
use regex::Regex;
|
use regex::Regex;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
|
@ -13,45 +12,25 @@ use ruma::{
|
||||||
OwnedRoomAliasId, OwnedServerName,
|
OwnedRoomAliasId, OwnedServerName,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
|
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
|
||||||
///
|
///
|
||||||
/// Creates a new room alias on this server.
|
/// Creates a new room alias on this server.
|
||||||
pub async fn create_alias_route(
|
pub async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> {
|
||||||
body: Ruma<create_alias::v3::Request>,
|
|
||||||
) -> Result<create_alias::v3::Response> {
|
|
||||||
if body.room_alias.server_name() != services().globals.server_name() {
|
if body.room_alias.server_name() != services().globals.server_name() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Alias is from another server.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().globals.forbidden_room_names().is_match(body.room_alias.alias()) {
|
||||||
.globals
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
|
||||||
.forbidden_room_names()
|
|
||||||
.is_match(body.room_alias.alias())
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Room alias is forbidden.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
|
||||||
.rooms
|
|
||||||
.alias
|
|
||||||
.resolve_local_alias(&body.room_alias)?
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
return Err(Error::Conflict("Alias already exists."));
|
return Err(Error::Conflict("Alias already exists."));
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() {
|
||||||
.rooms
|
|
||||||
.alias
|
|
||||||
.set_alias(&body.room_alias, &body.room_id)
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||||
|
@ -67,34 +46,16 @@ pub async fn create_alias_route(
|
||||||
///
|
///
|
||||||
/// - TODO: additional access control checks
|
/// - TODO: additional access control checks
|
||||||
/// - TODO: Update canonical alias event
|
/// - TODO: Update canonical alias event
|
||||||
pub async fn delete_alias_route(
|
pub async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
|
||||||
body: Ruma<delete_alias::v3::Request>,
|
|
||||||
) -> Result<delete_alias::v3::Response> {
|
|
||||||
if body.room_alias.server_name() != services().globals.server_name() {
|
if body.room_alias.server_name() != services().globals.server_name() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Alias is from another server.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() {
|
||||||
.rooms
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||||
.alias
|
|
||||||
.resolve_local_alias(&body.room_alias)?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Alias does not exist.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services()
|
if services().rooms.alias.remove_alias(&body.room_alias).is_err() {
|
||||||
.rooms
|
|
||||||
.alias
|
|
||||||
.remove_alias(&body.room_alias)
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||||
|
@ -109,15 +70,11 @@ pub async fn delete_alias_route(
|
||||||
/// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
|
/// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
|
||||||
///
|
///
|
||||||
/// Resolve an alias locally or over federation.
|
/// Resolve an alias locally or over federation.
|
||||||
pub async fn get_alias_route(
|
pub async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> {
|
||||||
body: Ruma<get_alias::v3::Request>,
|
|
||||||
) -> Result<get_alias::v3::Response> {
|
|
||||||
get_alias_helper(body.body.room_alias).await
|
get_alias_helper(body.body.room_alias).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_alias_helper(
|
pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get_alias::v3::Response> {
|
||||||
room_alias: OwnedRoomAliasId,
|
|
||||||
) -> Result<get_alias::v3::Response> {
|
|
||||||
if room_alias.server_name() != services().globals.server_name() {
|
if room_alias.server_name() != services().globals.server_name() {
|
||||||
let response = services()
|
let response = services()
|
||||||
.sending
|
.sending
|
||||||
|
@ -134,20 +91,13 @@ pub(crate) async fn get_alias_helper(
|
||||||
let mut servers = response.servers;
|
let mut servers = response.servers;
|
||||||
|
|
||||||
// find active servers in room state cache to suggest
|
// find active servers in room state cache to suggest
|
||||||
for extra_servers in services()
|
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.room_servers(&room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
servers.push(extra_servers);
|
servers.push(extra_servers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert our server as the very first choice if in list
|
// insert our server as the very first choice if in list
|
||||||
if let Some(server_index) = servers
|
if let Some(server_index) =
|
||||||
.clone()
|
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||||
.into_iter()
|
|
||||||
.position(|server| server == services().globals.server_name())
|
|
||||||
{
|
{
|
||||||
servers.remove(server_index);
|
servers.remove(server_index);
|
||||||
servers.insert(0, services().globals.server_name().to_owned());
|
servers.insert(0, services().globals.server_name().to_owned());
|
||||||
|
@ -174,9 +124,7 @@ pub(crate) async fn get_alias_helper(
|
||||||
.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
|
.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
|
||||||
.collect::<Vec<_>>();
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
if aliases
|
if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||||
.iter()
|
|
||||||
.any(|aliases| aliases.is_match(room_alias.as_str()))
|
|
||||||
&& if let Some(opt_result) = services()
|
&& if let Some(opt_result) = services()
|
||||||
.sending
|
.sending
|
||||||
.send_appservice_request(
|
.send_appservice_request(
|
||||||
|
@ -190,50 +138,35 @@ pub(crate) async fn get_alias_helper(
|
||||||
opt_result.is_ok()
|
opt_result.is_ok()
|
||||||
} else {
|
} else {
|
||||||
false
|
false
|
||||||
}
|
} {
|
||||||
{
|
|
||||||
room_id = Some(
|
room_id = Some(
|
||||||
services()
|
services()
|
||||||
.rooms
|
.rooms
|
||||||
.alias
|
.alias
|
||||||
.resolve_local_alias(&room_alias)?
|
.resolve_local_alias(&room_alias)?
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?,
|
||||||
Error::bad_config("Appservice lied to us. Room does not exist.")
|
|
||||||
})?,
|
|
||||||
);
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let room_id = match room_id {
|
let room_id = match room_id {
|
||||||
Some(room_id) => room_id,
|
Some(room_id) => room_id,
|
||||||
None => {
|
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Room with alias not found.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut servers: Vec<OwnedServerName> = Vec::new();
|
let mut servers: Vec<OwnedServerName> = Vec::new();
|
||||||
|
|
||||||
// find active servers in room state cache to suggest
|
// find active servers in room state cache to suggest
|
||||||
for extra_servers in services()
|
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.room_servers(&room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
servers.push(extra_servers);
|
servers.push(extra_servers);
|
||||||
}
|
}
|
||||||
|
|
||||||
// insert our server as the very first choice if in list
|
// insert our server as the very first choice if in list
|
||||||
if let Some(server_index) = servers
|
if let Some(server_index) =
|
||||||
.clone()
|
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||||
.into_iter()
|
|
||||||
.position(|server| server == services().globals.server_name())
|
|
||||||
{
|
{
|
||||||
servers.remove(server_index);
|
servers.remove(server_index);
|
||||||
servers.insert(0, services().globals.server_name().to_owned());
|
servers.insert(0, services().globals.server_name().to_owned());
|
||||||
|
|
|
@ -1,15 +1,15 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::api::client::{
|
use ruma::api::client::{
|
||||||
backup::{
|
backup::{
|
||||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
|
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
|
||||||
create_backup_version, delete_backup_keys, delete_backup_keys_for_room,
|
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
|
||||||
delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys,
|
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
|
||||||
get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info,
|
get_latest_backup_info, update_backup_version,
|
||||||
update_backup_version,
|
|
||||||
},
|
},
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/room_keys/version`
|
/// # `POST /_matrix/client/r0/room_keys/version`
|
||||||
///
|
///
|
||||||
/// Creates a new backup.
|
/// Creates a new backup.
|
||||||
|
@ -17,23 +17,22 @@ pub async fn create_backup_version_route(
|
||||||
body: Ruma<create_backup_version::v3::Request>,
|
body: Ruma<create_backup_version::v3::Request>,
|
||||||
) -> Result<create_backup_version::v3::Response> {
|
) -> Result<create_backup_version::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let version = services()
|
let version = services().key_backups.create_backup(sender_user, &body.algorithm)?;
|
||||||
.key_backups
|
|
||||||
.create_backup(sender_user, &body.algorithm)?;
|
|
||||||
|
|
||||||
Ok(create_backup_version::v3::Response { version })
|
Ok(create_backup_version::v3::Response {
|
||||||
|
version,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
|
/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
|
||||||
///
|
///
|
||||||
/// Update information about an existing backup. Only `auth_data` can be modified.
|
/// Update information about an existing backup. Only `auth_data` can be
|
||||||
|
/// modified.
|
||||||
pub async fn update_backup_version_route(
|
pub async fn update_backup_version_route(
|
||||||
body: Ruma<update_backup_version::v3::Request>,
|
body: Ruma<update_backup_version::v3::Request>,
|
||||||
) -> Result<update_backup_version::v3::Response> {
|
) -> Result<update_backup_version::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
services()
|
services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||||
.key_backups
|
|
||||||
.update_backup(sender_user, &body.version, &body.algorithm)?;
|
|
||||||
|
|
||||||
Ok(update_backup_version::v3::Response {})
|
Ok(update_backup_version::v3::Response {})
|
||||||
}
|
}
|
||||||
|
@ -49,10 +48,7 @@ pub async fn get_latest_backup_info_route(
|
||||||
let (version, algorithm) = services()
|
let (version, algorithm) = services()
|
||||||
.key_backups
|
.key_backups
|
||||||
.get_latest_backup(sender_user)?
|
.get_latest_backup(sender_user)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Key backup does not exist.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(get_latest_backup_info::v3::Response {
|
Ok(get_latest_backup_info::v3::Response {
|
||||||
algorithm,
|
algorithm,
|
||||||
|
@ -65,27 +61,17 @@ pub async fn get_latest_backup_info_route(
|
||||||
/// # `GET /_matrix/client/r0/room_keys/version`
|
/// # `GET /_matrix/client/r0/room_keys/version`
|
||||||
///
|
///
|
||||||
/// Get information about an existing backup.
|
/// Get information about an existing backup.
|
||||||
pub async fn get_backup_info_route(
|
pub async fn get_backup_info_route(body: Ruma<get_backup_info::v3::Request>) -> Result<get_backup_info::v3::Response> {
|
||||||
body: Ruma<get_backup_info::v3::Request>,
|
|
||||||
) -> Result<get_backup_info::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let algorithm = services()
|
let algorithm = services()
|
||||||
.key_backups
|
.key_backups
|
||||||
.get_backup(sender_user, &body.version)?
|
.get_backup(sender_user, &body.version)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Key backup does not exist.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(get_backup_info::v3::Response {
|
Ok(get_backup_info::v3::Response {
|
||||||
algorithm,
|
algorithm,
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
version: body.version.clone(),
|
version: body.version.clone(),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -94,15 +80,14 @@ pub async fn get_backup_info_route(
|
||||||
///
|
///
|
||||||
/// Delete an existing key backup.
|
/// Delete an existing key backup.
|
||||||
///
|
///
|
||||||
/// - Deletes both information about the backup, as well as all key data related to the backup
|
/// - Deletes both information about the backup, as well as all key data related
|
||||||
|
/// to the backup
|
||||||
pub async fn delete_backup_version_route(
|
pub async fn delete_backup_version_route(
|
||||||
body: Ruma<delete_backup_version::v3::Request>,
|
body: Ruma<delete_backup_version::v3::Request>,
|
||||||
) -> Result<delete_backup_version::v3::Response> {
|
) -> Result<delete_backup_version::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().key_backups.delete_backup(sender_user, &body.version)?;
|
||||||
.key_backups
|
|
||||||
.delete_backup(sender_user, &body.version)?;
|
|
||||||
|
|
||||||
Ok(delete_backup_version::v3::Response {})
|
Ok(delete_backup_version::v3::Response {})
|
||||||
}
|
}
|
||||||
|
@ -111,20 +96,14 @@ pub async fn delete_backup_version_route(
|
||||||
///
|
///
|
||||||
/// Add the received backup keys to the database.
|
/// Add the received backup keys to the database.
|
||||||
///
|
///
|
||||||
/// - Only manipulating the most recently created version of the backup is allowed
|
/// - Only manipulating the most recently created version of the backup is
|
||||||
|
/// allowed
|
||||||
/// - Adds the keys to the backup
|
/// - Adds the keys to the backup
|
||||||
/// - Returns the new number of keys in this backup and the etag
|
/// - Returns the new number of keys in this backup and the etag
|
||||||
pub async fn add_backup_keys_route(
|
pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> {
|
||||||
body: Ruma<add_backup_keys::v3::Request>,
|
|
||||||
) -> Result<add_backup_keys::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if Some(&body.version)
|
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||||
!= services()
|
|
||||||
.key_backups
|
|
||||||
.get_latest_backup_version(sender_user)?
|
|
||||||
.as_ref()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"You may only manipulate the most recently created version of the backup.",
|
"You may only manipulate the most recently created version of the backup.",
|
||||||
|
@ -133,24 +112,13 @@ pub async fn add_backup_keys_route(
|
||||||
|
|
||||||
for (room_id, room) in &body.rooms {
|
for (room_id, room) in &body.rooms {
|
||||||
for (session_id, key_data) in &room.sessions {
|
for (session_id, key_data) in &room.sessions {
|
||||||
services().key_backups.add_key(
|
services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
|
||||||
sender_user,
|
|
||||||
&body.version,
|
|
||||||
room_id,
|
|
||||||
session_id,
|
|
||||||
key_data,
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(add_backup_keys::v3::Response {
|
Ok(add_backup_keys::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -158,7 +126,8 @@ pub async fn add_backup_keys_route(
|
||||||
///
|
///
|
||||||
/// Add the received backup keys to the database.
|
/// Add the received backup keys to the database.
|
||||||
///
|
///
|
||||||
/// - Only manipulating the most recently created version of the backup is allowed
|
/// - Only manipulating the most recently created version of the backup is
|
||||||
|
/// allowed
|
||||||
/// - Adds the keys to the backup
|
/// - Adds the keys to the backup
|
||||||
/// - Returns the new number of keys in this backup and the etag
|
/// - Returns the new number of keys in this backup and the etag
|
||||||
pub async fn add_backup_keys_for_room_route(
|
pub async fn add_backup_keys_for_room_route(
|
||||||
|
@ -166,12 +135,7 @@ pub async fn add_backup_keys_for_room_route(
|
||||||
) -> Result<add_backup_keys_for_room::v3::Response> {
|
) -> Result<add_backup_keys_for_room::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if Some(&body.version)
|
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||||
!= services()
|
|
||||||
.key_backups
|
|
||||||
.get_latest_backup_version(sender_user)?
|
|
||||||
.as_ref()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"You may only manipulate the most recently created version of the backup.",
|
"You may only manipulate the most recently created version of the backup.",
|
||||||
|
@ -179,23 +143,12 @@ pub async fn add_backup_keys_for_room_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
for (session_id, key_data) in &body.sessions {
|
for (session_id, key_data) in &body.sessions {
|
||||||
services().key_backups.add_key(
|
services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
|
||||||
sender_user,
|
|
||||||
&body.version,
|
|
||||||
&body.room_id,
|
|
||||||
session_id,
|
|
||||||
key_data,
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(add_backup_keys_for_room::v3::Response {
|
Ok(add_backup_keys_for_room::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -203,7 +156,8 @@ pub async fn add_backup_keys_for_room_route(
|
||||||
///
|
///
|
||||||
/// Add the received backup key to the database.
|
/// Add the received backup key to the database.
|
||||||
///
|
///
|
||||||
/// - Only manipulating the most recently created version of the backup is allowed
|
/// - Only manipulating the most recently created version of the backup is
|
||||||
|
/// allowed
|
||||||
/// - Adds the keys to the backup
|
/// - Adds the keys to the backup
|
||||||
/// - Returns the new number of keys in this backup and the etag
|
/// - Returns the new number of keys in this backup and the etag
|
||||||
pub async fn add_backup_keys_for_session_route(
|
pub async fn add_backup_keys_for_session_route(
|
||||||
|
@ -211,48 +165,32 @@ pub async fn add_backup_keys_for_session_route(
|
||||||
) -> Result<add_backup_keys_for_session::v3::Response> {
|
) -> Result<add_backup_keys_for_session::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if Some(&body.version)
|
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||||
!= services()
|
|
||||||
.key_backups
|
|
||||||
.get_latest_backup_version(sender_user)?
|
|
||||||
.as_ref()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"You may only manipulate the most recently created version of the backup.",
|
"You may only manipulate the most recently created version of the backup.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
services().key_backups.add_key(
|
services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
|
||||||
sender_user,
|
|
||||||
&body.version,
|
|
||||||
&body.room_id,
|
|
||||||
&body.session_id,
|
|
||||||
&body.session_data,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(add_backup_keys_for_session::v3::Response {
|
Ok(add_backup_keys_for_session::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/room_keys/keys`
|
/// # `GET /_matrix/client/r0/room_keys/keys`
|
||||||
///
|
///
|
||||||
/// Retrieves all keys from the backup.
|
/// Retrieves all keys from the backup.
|
||||||
pub async fn get_backup_keys_route(
|
pub async fn get_backup_keys_route(body: Ruma<get_backup_keys::v3::Request>) -> Result<get_backup_keys::v3::Response> {
|
||||||
body: Ruma<get_backup_keys::v3::Request>,
|
|
||||||
) -> Result<get_backup_keys::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
||||||
|
|
||||||
Ok(get_backup_keys::v3::Response { rooms })
|
Ok(get_backup_keys::v3::Response {
|
||||||
|
rooms,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
|
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
|
||||||
|
@ -263,11 +201,11 @@ pub async fn get_backup_keys_for_room_route(
|
||||||
) -> Result<get_backup_keys_for_room::v3::Response> {
|
) -> Result<get_backup_keys_for_room::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let sessions = services()
|
let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?;
|
||||||
.key_backups
|
|
||||||
.get_room(sender_user, &body.version, &body.room_id)?;
|
|
||||||
|
|
||||||
Ok(get_backup_keys_for_room::v3::Response { sessions })
|
Ok(get_backup_keys_for_room::v3::Response {
|
||||||
|
sessions,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
||||||
|
@ -278,15 +216,14 @@ pub async fn get_backup_keys_for_session_route(
|
||||||
) -> Result<get_backup_keys_for_session::v3::Response> {
|
) -> Result<get_backup_keys_for_session::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let key_data = services()
|
let key_data =
|
||||||
.key_backups
|
services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or(
|
||||||
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
|
Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."),
|
||||||
.ok_or(Error::BadRequest(
|
)?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Backup key not found for this user's session.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(get_backup_keys_for_session::v3::Response { key_data })
|
Ok(get_backup_keys_for_session::v3::Response {
|
||||||
|
key_data,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `DELETE /_matrix/client/r0/room_keys/keys`
|
/// # `DELETE /_matrix/client/r0/room_keys/keys`
|
||||||
|
@ -297,18 +234,11 @@ pub async fn delete_backup_keys_route(
|
||||||
) -> Result<delete_backup_keys::v3::Response> {
|
) -> Result<delete_backup_keys::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().key_backups.delete_all_keys(sender_user, &body.version)?;
|
||||||
.key_backups
|
|
||||||
.delete_all_keys(sender_user, &body.version)?;
|
|
||||||
|
|
||||||
Ok(delete_backup_keys::v3::Response {
|
Ok(delete_backup_keys::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -320,18 +250,11 @@ pub async fn delete_backup_keys_for_room_route(
|
||||||
) -> Result<delete_backup_keys_for_room::v3::Response> {
|
) -> Result<delete_backup_keys_for_room::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||||
.key_backups
|
|
||||||
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
|
||||||
|
|
||||||
Ok(delete_backup_keys_for_room::v3::Response {
|
Ok(delete_backup_keys_for_room::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -343,20 +266,10 @@ pub async fn delete_backup_keys_for_session_route(
|
||||||
) -> Result<delete_backup_keys_for_session::v3::Response> {
|
) -> Result<delete_backup_keys_for_session::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services().key_backups.delete_room_key(
|
services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
|
||||||
sender_user,
|
|
||||||
&body.version,
|
|
||||||
&body.room_id,
|
|
||||||
&body.session_id,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(delete_backup_keys_for_session::v3::Response {
|
Ok(delete_backup_keys_for_session::v3::Response {
|
||||||
count: (services()
|
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||||
.key_backups
|
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||||
.count_keys(sender_user, &body.version)? as u32)
|
|
||||||
.into(),
|
|
||||||
etag: services()
|
|
||||||
.key_backups
|
|
||||||
.get_etag(sender_user, &body.version)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,12 +1,15 @@
|
||||||
use crate::{services, Result, Ruma};
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use ruma::api::client::discovery::get_capabilities::{
|
use ruma::api::client::discovery::get_capabilities::{
|
||||||
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
||||||
};
|
};
|
||||||
use std::collections::BTreeMap;
|
|
||||||
|
use crate::{services, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/capabilities`
|
/// # `GET /_matrix/client/r0/capabilities`
|
||||||
///
|
///
|
||||||
/// Get information on the supported feature set and other relevent capabilities of this server.
|
/// Get information on the supported feature set and other relevent capabilities
|
||||||
|
/// of this server.
|
||||||
pub async fn get_capabilities_route(
|
pub async fn get_capabilities_route(
|
||||||
_body: Ruma<get_capabilities::v3::Request>,
|
_body: Ruma<get_capabilities::v3::Request>,
|
||||||
) -> Result<get_capabilities::v3::Response> {
|
) -> Result<get_capabilities::v3::Response> {
|
||||||
|
@ -24,5 +27,7 @@ pub async fn get_capabilities_route(
|
||||||
available,
|
available,
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(get_capabilities::v3::Response { capabilities })
|
Ok(get_capabilities::v3::Response {
|
||||||
|
capabilities,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,6 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
config::{
|
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
|
||||||
get_global_account_data, get_room_account_data, set_global_account_data,
|
|
||||||
set_room_account_data,
|
|
||||||
},
|
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
},
|
},
|
||||||
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
|
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
|
||||||
|
@ -13,6 +9,8 @@ use ruma::{
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use serde_json::{json, value::RawValue as RawJsonValue};
|
use serde_json::{json, value::RawValue as RawJsonValue};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
|
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
|
||||||
///
|
///
|
||||||
/// Sets some account data for the sender user.
|
/// Sets some account data for the sender user.
|
||||||
|
@ -82,7 +80,9 @@ pub async fn get_global_account_data_route(
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||||
.content;
|
.content;
|
||||||
|
|
||||||
Ok(get_global_account_data::v3::Response { account_data })
|
Ok(get_global_account_data::v3::Response {
|
||||||
|
account_data,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
|
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
|
||||||
|
@ -102,7 +102,9 @@ pub async fn get_room_account_data_route(
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||||
.content;
|
.content;
|
||||||
|
|
||||||
Ok(get_room_account_data::v3::Response { account_data })
|
Ok(get_room_account_data::v3::Response {
|
||||||
|
account_data,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
|
|
|
@ -1,20 +1,21 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
use std::collections::HashSet;
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
||||||
events::StateEventType,
|
events::StateEventType,
|
||||||
};
|
};
|
||||||
use std::collections::HashSet;
|
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
|
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
|
||||||
///
|
///
|
||||||
/// Allows loading room history around an event.
|
/// Allows loading room history around an event.
|
||||||
///
|
///
|
||||||
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was
|
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||||
|
/// if the user was
|
||||||
/// joined, depending on history_visibility)
|
/// joined, depending on history_visibility)
|
||||||
pub async fn get_context_route(
|
pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<get_context::v3::Response> {
|
||||||
body: Ruma<get_context::v3::Request>,
|
|
||||||
) -> Result<get_context::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -31,28 +32,17 @@ pub async fn get_context_route(
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.get_pdu_count(&body.event_id)?
|
.get_pdu_count(&body.event_id)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Base event id not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let base_event =
|
let base_event = services()
|
||||||
services()
|
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.get_pdu(&body.event_id)?
|
.get_pdu(&body.event_id)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Base event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let room_id = base_event.room_id.clone();
|
let room_id = base_event.room_id.clone();
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? {
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.user_can_see_event(sender_user, &room_id, &body.event_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view this event.",
|
"You don't have permission to view this event.",
|
||||||
|
@ -101,15 +91,10 @@ pub async fn get_context_route(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let start_token = events_before
|
let start_token =
|
||||||
.last()
|
events_before.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||||
.map(|(count, _)| count.stringify())
|
|
||||||
.unwrap_or_else(|| base_token.stringify());
|
|
||||||
|
|
||||||
let events_before: Vec<_> = events_before
|
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||||
.into_iter()
|
|
||||||
.map(|(_, pdu)| pdu.to_room_event())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let events_after: Vec<_> = services()
|
let events_after: Vec<_> = services()
|
||||||
.rooms
|
.rooms
|
||||||
|
@ -138,42 +123,25 @@ pub async fn get_context_route(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
|
let shortstatehash = match services()
|
||||||
events_after
|
|
||||||
.last()
|
|
||||||
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
|
|
||||||
)? {
|
|
||||||
Some(s) => s,
|
|
||||||
None => services()
|
|
||||||
.rooms
|
|
||||||
.state
|
|
||||||
.get_room_shortstatehash(&room_id)?
|
|
||||||
.expect("All rooms have state"),
|
|
||||||
};
|
|
||||||
|
|
||||||
let state_ids = services()
|
|
||||||
.rooms
|
.rooms
|
||||||
.state_accessor
|
.state_accessor
|
||||||
.state_full_ids(shortstatehash)
|
.pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))?
|
||||||
.await?;
|
{
|
||||||
|
Some(s) => s,
|
||||||
|
None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"),
|
||||||
|
};
|
||||||
|
|
||||||
let end_token = events_after
|
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
|
||||||
.last()
|
|
||||||
.map(|(count, _)| count.stringify())
|
|
||||||
.unwrap_or_else(|| base_token.stringify());
|
|
||||||
|
|
||||||
let events_after: Vec<_> = events_after
|
let end_token = events_after.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||||
.into_iter()
|
|
||||||
.map(|(_, pdu)| pdu.to_room_event())
|
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||||
.collect();
|
|
||||||
|
|
||||||
let mut state = Vec::new();
|
let mut state = Vec::new();
|
||||||
|
|
||||||
for (shortstatekey, id) in state_ids {
|
for (shortstatekey, id) in state_ids {
|
||||||
let (event_type, state_key) = services()
|
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?;
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_statekey_from_short(shortstatekey)?;
|
|
||||||
|
|
||||||
if event_type != StateEventType::RoomMember {
|
if event_type != StateEventType::RoomMember {
|
||||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||||
|
@ -181,7 +149,7 @@ pub async fn get_context_route(
|
||||||
None => {
|
None => {
|
||||||
error!("Pdu in state not found: {}", id);
|
error!("Pdu in state not found: {}", id);
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
state.push(pdu.to_state_event());
|
state.push(pdu.to_state_event());
|
||||||
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
||||||
|
@ -190,7 +158,7 @@ pub async fn get_context_route(
|
||||||
None => {
|
None => {
|
||||||
error!("Pdu in state not found: {}", id);
|
error!("Pdu in state not found: {}", id);
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
state.push(pdu.to_state_event());
|
state.push(pdu.to_state_event());
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use crate::{services, utils, Error, Result, Ruma};
|
|
||||||
use ruma::api::client::{
|
use ruma::api::client::{
|
||||||
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
|
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
|
@ -6,13 +5,12 @@ use ruma::api::client::{
|
||||||
};
|
};
|
||||||
|
|
||||||
use super::SESSION_ID_LENGTH;
|
use super::SESSION_ID_LENGTH;
|
||||||
|
use crate::{services, utils, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/devices`
|
/// # `GET /_matrix/client/r0/devices`
|
||||||
///
|
///
|
||||||
/// Get metadata on all devices of the sender user.
|
/// Get metadata on all devices of the sender user.
|
||||||
pub async fn get_devices_route(
|
pub async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> {
|
||||||
body: Ruma<get_devices::v3::Request>,
|
|
||||||
) -> Result<get_devices::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let devices: Vec<device::Device> = services()
|
let devices: Vec<device::Device> = services()
|
||||||
|
@ -21,15 +19,15 @@ pub async fn get_devices_route(
|
||||||
.filter_map(std::result::Result::ok) // Filter out buggy devices
|
.filter_map(std::result::Result::ok) // Filter out buggy devices
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
Ok(get_devices::v3::Response { devices })
|
Ok(get_devices::v3::Response {
|
||||||
|
devices,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/devices/{deviceId}`
|
/// # `GET /_matrix/client/r0/devices/{deviceId}`
|
||||||
///
|
///
|
||||||
/// Get metadata on a single device of the sender user.
|
/// Get metadata on a single device of the sender user.
|
||||||
pub async fn get_device_route(
|
pub async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> {
|
||||||
body: Ruma<get_device::v3::Request>,
|
|
||||||
) -> Result<get_device::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let device = services()
|
let device = services()
|
||||||
|
@ -37,15 +35,15 @@ pub async fn get_device_route(
|
||||||
.get_device_metadata(sender_user, &body.body.device_id)?
|
.get_device_metadata(sender_user, &body.body.device_id)?
|
||||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||||
|
|
||||||
Ok(get_device::v3::Response { device })
|
Ok(get_device::v3::Response {
|
||||||
|
device,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
|
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
|
||||||
///
|
///
|
||||||
/// Updates the metadata on a given device of the sender user.
|
/// Updates the metadata on a given device of the sender user.
|
||||||
pub async fn update_device_route(
|
pub async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> {
|
||||||
body: Ruma<update_device::v3::Request>,
|
|
||||||
) -> Result<update_device::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let mut device = services()
|
let mut device = services()
|
||||||
|
@ -55,9 +53,7 @@ pub async fn update_device_route(
|
||||||
|
|
||||||
device.display_name = body.display_name.clone();
|
device.display_name = body.display_name.clone();
|
||||||
|
|
||||||
services()
|
services().users.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||||
.users
|
|
||||||
.update_device_metadata(sender_user, &body.device_id, &device)?;
|
|
||||||
|
|
||||||
Ok(update_device::v3::Response {})
|
Ok(update_device::v3::Response {})
|
||||||
}
|
}
|
||||||
|
@ -68,12 +64,11 @@ pub async fn update_device_route(
|
||||||
///
|
///
|
||||||
/// - Requires UIAA to verify user password
|
/// - Requires UIAA to verify user password
|
||||||
/// - Invalidates access token
|
/// - Invalidates access token
|
||||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets to-device events
|
/// - Forgets to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
pub async fn delete_device_route(
|
pub async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> {
|
||||||
body: Ruma<delete_device::v3::Request>,
|
|
||||||
) -> Result<delete_device::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -89,27 +84,20 @@ pub async fn delete_device_route(
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) =
|
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||||
services()
|
|
||||||
.uiaa
|
|
||||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
|
||||||
if !worked {
|
if !worked {
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
}
|
}
|
||||||
// Success!
|
// Success!
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services()
|
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||||
.uiaa
|
|
||||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.remove_device(sender_user, &body.device_id)?;
|
||||||
.users
|
|
||||||
.remove_device(sender_user, &body.device_id)?;
|
|
||||||
|
|
||||||
Ok(delete_device::v3::Response {})
|
Ok(delete_device::v3::Response {})
|
||||||
}
|
}
|
||||||
|
@ -122,12 +110,11 @@ pub async fn delete_device_route(
|
||||||
///
|
///
|
||||||
/// For each device:
|
/// For each device:
|
||||||
/// - Invalidates access token
|
/// - Invalidates access token
|
||||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets to-device events
|
/// - Forgets to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
pub async fn delete_devices_route(
|
pub async fn delete_devices_route(body: Ruma<delete_devices::v3::Request>) -> Result<delete_devices::v3::Response> {
|
||||||
body: Ruma<delete_devices::v3::Request>,
|
|
||||||
) -> Result<delete_devices::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -143,19 +130,14 @@ pub async fn delete_devices_route(
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) =
|
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||||
services()
|
|
||||||
.uiaa
|
|
||||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
|
||||||
if !worked {
|
if !worked {
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
}
|
}
|
||||||
// Success!
|
// Success!
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services()
|
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||||
.uiaa
|
|
||||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||||
|
|
|
@ -1,11 +1,7 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::{
|
client::{
|
||||||
directory::{
|
directory::{get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility},
|
||||||
get_public_rooms, get_public_rooms_filtered, get_room_visibility,
|
|
||||||
set_room_visibility,
|
|
||||||
},
|
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
room,
|
room,
|
||||||
},
|
},
|
||||||
|
@ -28,6 +24,8 @@ use ruma::{
|
||||||
};
|
};
|
||||||
use tracing::{error, info, warn};
|
use tracing::{error, info, warn};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/publicRooms`
|
/// # `POST /_matrix/client/v3/publicRooms`
|
||||||
///
|
///
|
||||||
/// Lists the public rooms on this server.
|
/// Lists the public rooms on this server.
|
||||||
|
@ -36,11 +34,7 @@ use tracing::{error, info, warn};
|
||||||
pub async fn get_public_rooms_filtered_route(
|
pub async fn get_public_rooms_filtered_route(
|
||||||
body: Ruma<get_public_rooms_filtered::v3::Request>,
|
body: Ruma<get_public_rooms_filtered::v3::Request>,
|
||||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||||
if !services()
|
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||||
.globals
|
|
||||||
.config
|
|
||||||
.allow_public_room_directory_without_auth
|
|
||||||
{
|
|
||||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -62,11 +56,7 @@ pub async fn get_public_rooms_filtered_route(
|
||||||
pub async fn get_public_rooms_route(
|
pub async fn get_public_rooms_route(
|
||||||
body: Ruma<get_public_rooms::v3::Request>,
|
body: Ruma<get_public_rooms::v3::Request>,
|
||||||
) -> Result<get_public_rooms::v3::Response> {
|
) -> Result<get_public_rooms::v3::Response> {
|
||||||
if !services()
|
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||||
.globals
|
|
||||||
.config
|
|
||||||
.allow_public_room_directory_without_auth
|
|
||||||
{
|
|
||||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -106,14 +96,14 @@ pub async fn set_room_visibility_route(
|
||||||
room::Visibility::Public => {
|
room::Visibility::Public => {
|
||||||
services().rooms.directory.set_public(&body.room_id)?;
|
services().rooms.directory.set_public(&body.room_id)?;
|
||||||
info!("{} made {} public", sender_user, body.room_id);
|
info!("{} made {} public", sender_user, body.room_id);
|
||||||
}
|
},
|
||||||
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
||||||
_ => {
|
_ => {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Room visibility type is not supported.",
|
"Room visibility type is not supported.",
|
||||||
));
|
));
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(set_room_visibility::v3::Response {})
|
Ok(set_room_visibility::v3::Response {})
|
||||||
|
@ -140,15 +130,9 @@ pub async fn get_room_visibility_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_public_rooms_filtered_helper(
|
pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
server: Option<&ServerName>,
|
server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork,
|
||||||
limit: Option<UInt>,
|
|
||||||
since: Option<&str>,
|
|
||||||
filter: &Filter,
|
|
||||||
_network: &RoomNetwork,
|
|
||||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||||
if let Some(other_server) =
|
if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) {
|
||||||
server.filter(|server| *server != services().globals.server_name().as_str())
|
|
||||||
{
|
|
||||||
let response = services()
|
let response = services()
|
||||||
.sending
|
.sending
|
||||||
.send_federation_request(
|
.send_federation_request(
|
||||||
|
@ -181,12 +165,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
let backwards = match characters.next() {
|
let backwards = match characters.next() {
|
||||||
Some('n') => false,
|
Some('n') => false,
|
||||||
Some('p') => true,
|
Some('p') => true,
|
||||||
_ => {
|
_ => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid `since` token",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
num_since = characters
|
num_since = characters
|
||||||
|
@ -214,9 +193,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
.map_or(Ok(None), |s| {
|
.map_or(Ok(None), |s| {
|
||||||
serde_json::from_str(s.content.get())
|
serde_json::from_str(s.content.get())
|
||||||
.map(|c: RoomCanonicalAliasEventContent| c.alias)
|
.map(|c: RoomCanonicalAliasEventContent| c.alias)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
|
||||||
Error::bad_database("Invalid canonical alias event in database.")
|
|
||||||
})
|
|
||||||
})?,
|
})?,
|
||||||
name: services().rooms.state_accessor.get_name(&room_id)?,
|
name: services().rooms.state_accessor.get_name(&room_id)?,
|
||||||
num_joined_members: services()
|
num_joined_members: services()
|
||||||
|
@ -251,11 +228,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
.map(|c: RoomHistoryVisibilityEventContent| {
|
.map(|c: RoomHistoryVisibilityEventContent| {
|
||||||
c.history_visibility == HistoryVisibility::WorldReadable
|
c.history_visibility == HistoryVisibility::WorldReadable
|
||||||
})
|
})
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid room history visibility event in database."))
|
||||||
Error::bad_database(
|
|
||||||
"Invalid room history visibility event in database.",
|
|
||||||
)
|
|
||||||
})
|
|
||||||
})?,
|
})?,
|
||||||
guest_can_join: services()
|
guest_can_join: services()
|
||||||
.rooms
|
.rooms
|
||||||
|
@ -263,12 +236,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
||||||
.map_or(Ok(false), |s| {
|
.map_or(Ok(false), |s| {
|
||||||
serde_json::from_str(s.content.get())
|
serde_json::from_str(s.content.get())
|
||||||
.map(|c: RoomGuestAccessEventContent| {
|
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
|
||||||
c.guest_access == GuestAccess::CanJoin
|
.map_err(|_| Error::bad_database("Invalid room guest access event in database."))
|
||||||
})
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("Invalid room guest access event in database.")
|
|
||||||
})
|
|
||||||
})?,
|
})?,
|
||||||
avatar_url: services()
|
avatar_url: services()
|
||||||
.rooms
|
.rooms
|
||||||
|
@ -277,9 +246,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
serde_json::from_str(s.content.get())
|
serde_json::from_str(s.content.get())
|
||||||
.map(|c: RoomAvatarEventContent| c.url)
|
.map(|c: RoomAvatarEventContent| c.url)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid room avatar event in database."))
|
||||||
Error::bad_database("Invalid room avatar event in database.")
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
// url is now an Option<String> so we must flatten
|
// url is now an Option<String> so we must flatten
|
||||||
|
@ -308,12 +275,10 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
.state_accessor
|
.state_accessor
|
||||||
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
|
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
|
||||||
.map(|s| {
|
.map(|s| {
|
||||||
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(
|
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(|e| {
|
||||||
|e| {
|
|
||||||
error!("Invalid room create event in database: {}", e);
|
error!("Invalid room create event in database: {}", e);
|
||||||
Error::BadDatabase("Invalid room create event in database.")
|
Error::BadDatabase("Invalid room create event in database.")
|
||||||
},
|
})
|
||||||
)
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.and_then(|e| e.room_type),
|
.and_then(|e| e.room_type),
|
||||||
|
@ -323,11 +288,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
})
|
})
|
||||||
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
|
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
|
||||||
.filter(|chunk| {
|
.filter(|chunk| {
|
||||||
if let Some(query) = filter
|
if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
|
||||||
.generic_search_term
|
|
||||||
.as_ref()
|
|
||||||
.map(|q| q.to_lowercase())
|
|
||||||
{
|
|
||||||
if let Some(name) = &chunk.name {
|
if let Some(name) = &chunk.name {
|
||||||
if name.as_str().to_lowercase().contains(&query) {
|
if name.as_str().to_lowercase().contains(&query) {
|
||||||
return true;
|
return true;
|
||||||
|
@ -359,11 +320,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
||||||
|
|
||||||
let total_room_count_estimate = (all_rooms.len() as u32).into();
|
let total_room_count_estimate = (all_rooms.len() as u32).into();
|
||||||
|
|
||||||
let chunk: Vec<_> = all_rooms
|
let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect();
|
||||||
.into_iter()
|
|
||||||
.skip(num_since as usize)
|
|
||||||
.take(limit as usize)
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
let prev_batch = if num_since == 0 {
|
let prev_batch = if num_since == 0 {
|
||||||
None
|
None
|
||||||
|
|
|
@ -1,17 +1,16 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::api::client::{
|
use ruma::api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
filter::{create_filter, get_filter},
|
filter::{create_filter, get_filter},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
|
/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
|
||||||
///
|
///
|
||||||
/// Loads a filter that was previously created.
|
/// Loads a filter that was previously created.
|
||||||
///
|
///
|
||||||
/// - A user can only access their own filters
|
/// - A user can only access their own filters
|
||||||
pub async fn get_filter_route(
|
pub async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> {
|
||||||
body: Ruma<get_filter::v3::Request>,
|
|
||||||
) -> Result<get_filter::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
|
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
|
||||||
Some(filter) => filter,
|
Some(filter) => filter,
|
||||||
|
@ -24,9 +23,7 @@ pub async fn get_filter_route(
|
||||||
/// # `PUT /_matrix/client/r0/user/{userId}/filter`
|
/// # `PUT /_matrix/client/r0/user/{userId}/filter`
|
||||||
///
|
///
|
||||||
/// Creates a new filter to be used by other endpoints.
|
/// Creates a new filter to be used by other endpoints.
|
||||||
pub async fn create_filter_route(
|
pub async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> {
|
||||||
body: Ruma<create_filter::v3::Request>,
|
|
||||||
) -> Result<create_filter::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
Ok(create_filter::v3::Response::new(
|
Ok(create_filter::v3::Response::new(
|
||||||
services().users.create_filter(sender_user, &body.filter)?,
|
services().users.create_filter(sender_user, &body.filter)?,
|
||||||
|
|
|
@ -1,14 +1,14 @@
|
||||||
use super::SESSION_ID_LENGTH;
|
use std::{
|
||||||
use crate::{services, utils, Error, Result, Ruma};
|
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||||
|
time::{Duration, Instant},
|
||||||
|
};
|
||||||
|
|
||||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::{
|
client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
keys::{
|
keys::{claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, upload_signing_keys},
|
||||||
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures,
|
|
||||||
upload_signing_keys,
|
|
||||||
},
|
|
||||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||||
},
|
},
|
||||||
federation,
|
federation,
|
||||||
|
@ -17,48 +17,36 @@ use ruma::{
|
||||||
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
|
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
|
||||||
};
|
};
|
||||||
use serde_json::json;
|
use serde_json::json;
|
||||||
use std::{
|
|
||||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
|
||||||
time::{Duration, Instant},
|
|
||||||
};
|
|
||||||
use tracing::{debug, error};
|
use tracing::{debug, error};
|
||||||
|
|
||||||
|
use super::SESSION_ID_LENGTH;
|
||||||
|
use crate::{services, utils, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/keys/upload`
|
/// # `POST /_matrix/client/r0/keys/upload`
|
||||||
///
|
///
|
||||||
/// Publish end-to-end encryption keys for the sender device.
|
/// Publish end-to-end encryption keys for the sender device.
|
||||||
///
|
///
|
||||||
/// - Adds one time keys
|
/// - Adds one time keys
|
||||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
|
/// - If there are no device keys yet: Adds device keys (TODO: merge with
|
||||||
pub async fn upload_keys_route(
|
/// existing keys?)
|
||||||
body: Ruma<upload_keys::v3::Request>,
|
pub async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> {
|
||||||
) -> Result<upload_keys::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
for (key_key, key_value) in &body.one_time_keys {
|
for (key_key, key_value) in &body.one_time_keys {
|
||||||
services()
|
services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||||
.users
|
|
||||||
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(device_keys) = &body.device_keys {
|
if let Some(device_keys) = &body.device_keys {
|
||||||
// TODO: merge this and the existing event?
|
// TODO: merge this and the existing event?
|
||||||
// This check is needed to assure that signatures are kept
|
// This check is needed to assure that signatures are kept
|
||||||
if services()
|
if services().users.get_device_keys(sender_user, sender_device)?.is_none() {
|
||||||
.users
|
services().users.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||||
.get_device_keys(sender_user, sender_device)?
|
|
||||||
.is_none()
|
|
||||||
{
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.add_device_keys(sender_user, sender_device, device_keys)?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(upload_keys::v3::Response {
|
Ok(upload_keys::v3::Response {
|
||||||
one_time_key_counts: services()
|
one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?,
|
||||||
.users
|
|
||||||
.count_one_time_keys(sender_user, sender_device)?,
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -68,7 +56,8 @@ pub async fn upload_keys_route(
|
||||||
///
|
///
|
||||||
/// - Always fetches users from other servers over federation
|
/// - Always fetches users from other servers over federation
|
||||||
/// - Gets master keys, self-signing keys, user signing keys and device keys.
|
/// - Gets master keys, self-signing keys, user signing keys and device keys.
|
||||||
/// - The master and self-signing keys contain signatures that the user is allowed to see
|
/// - The master and self-signing keys contain signatures that the user is
|
||||||
|
/// allowed to see
|
||||||
pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> {
|
pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -86,9 +75,7 @@ pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_key
|
||||||
/// # `POST /_matrix/client/r0/keys/claim`
|
/// # `POST /_matrix/client/r0/keys/claim`
|
||||||
///
|
///
|
||||||
/// Claims one-time keys
|
/// Claims one-time keys
|
||||||
pub async fn claim_keys_route(
|
pub async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> {
|
||||||
body: Ruma<claim_keys::v3::Request>,
|
|
||||||
) -> Result<claim_keys::v3::Response> {
|
|
||||||
let response = claim_keys_helper(&body.one_time_keys).await?;
|
let response = claim_keys_helper(&body.one_time_keys).await?;
|
||||||
|
|
||||||
Ok(response)
|
Ok(response)
|
||||||
|
@ -117,19 +104,14 @@ pub async fn upload_signing_keys_route(
|
||||||
};
|
};
|
||||||
|
|
||||||
if let Some(auth) = &body.auth {
|
if let Some(auth) = &body.auth {
|
||||||
let (worked, uiaainfo) =
|
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||||
services()
|
|
||||||
.uiaa
|
|
||||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
|
||||||
if !worked {
|
if !worked {
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
}
|
}
|
||||||
// Success!
|
// Success!
|
||||||
} else if let Some(json) = body.json_body {
|
} else if let Some(json) = body.json_body {
|
||||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||||
services()
|
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||||
.uiaa
|
|
||||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
|
||||||
return Err(Error::Uiaa(uiaainfo));
|
return Err(Error::Uiaa(uiaainfo));
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||||
|
@ -163,20 +145,11 @@ pub async fn upload_signatures_route(
|
||||||
|
|
||||||
for signature in key
|
for signature in key
|
||||||
.get("signatures")
|
.get("signatures")
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Missing signatures field."))?
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Missing signatures field.",
|
|
||||||
))?
|
|
||||||
.get(sender_user.to_string())
|
.get(sender_user.to_string())
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid user in signatures field."))?
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid user in signatures field.",
|
|
||||||
))?
|
|
||||||
.as_object()
|
.as_object()
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature."))?
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid signature.",
|
|
||||||
))?
|
|
||||||
.clone()
|
.clone()
|
||||||
.into_iter()
|
.into_iter()
|
||||||
{
|
{
|
||||||
|
@ -186,15 +159,10 @@ pub async fn upload_signatures_route(
|
||||||
signature
|
signature
|
||||||
.1
|
.1
|
||||||
.as_str()
|
.as_str()
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid signature value.",
|
|
||||||
))?
|
|
||||||
.to_owned(),
|
.to_owned(),
|
||||||
);
|
);
|
||||||
services()
|
services().users.sign_key(user_id, key_id, signature, sender_user)?;
|
||||||
.users
|
|
||||||
.sign_key(user_id, key_id, signature, sender_user)?;
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -206,12 +174,11 @@ pub async fn upload_signatures_route(
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/keys/changes`
|
/// # `POST /_matrix/client/r0/keys/changes`
|
||||||
///
|
///
|
||||||
/// Gets a list of users who have updated their device identity keys since the previous sync token.
|
/// Gets a list of users who have updated their device identity keys since the
|
||||||
|
/// previous sync token.
|
||||||
///
|
///
|
||||||
/// - TODO: left users
|
/// - TODO: left users
|
||||||
pub async fn get_key_changes_route(
|
pub async fn get_key_changes_route(body: Ruma<get_key_changes::v3::Request>) -> Result<get_key_changes::v3::Response> {
|
||||||
body: Ruma<get_key_changes::v3::Request>,
|
|
||||||
) -> Result<get_key_changes::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let mut device_list_updates = HashSet::new();
|
let mut device_list_updates = HashSet::new();
|
||||||
|
@ -221,35 +188,20 @@ pub async fn get_key_changes_route(
|
||||||
.users
|
.users
|
||||||
.keys_changed(
|
.keys_changed(
|
||||||
sender_user.as_str(),
|
sender_user.as_str(),
|
||||||
body.from
|
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||||
.parse()
|
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
|
||||||
Some(
|
|
||||||
body.to
|
|
||||||
.parse()
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
|
|
||||||
),
|
|
||||||
)
|
)
|
||||||
.filter_map(std::result::Result::ok),
|
.filter_map(std::result::Result::ok),
|
||||||
);
|
);
|
||||||
|
|
||||||
for room_id in services()
|
for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.rooms_joined(sender_user)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
device_list_updates.extend(
|
device_list_updates.extend(
|
||||||
services()
|
services()
|
||||||
.users
|
.users
|
||||||
.keys_changed(
|
.keys_changed(
|
||||||
room_id.as_ref(),
|
room_id.as_ref(),
|
||||||
body.from.parse().map_err(|_| {
|
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.")
|
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||||
})?,
|
|
||||||
Some(body.to.parse().map_err(|_| {
|
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
|
|
||||||
})?),
|
|
||||||
)
|
)
|
||||||
.filter_map(std::result::Result::ok),
|
.filter_map(std::result::Result::ok),
|
||||||
);
|
);
|
||||||
|
@ -261,9 +213,7 @@ pub async fn get_key_changes_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
sender_user: Option<&UserId>,
|
sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F,
|
||||||
device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
|
|
||||||
allowed_signatures: F,
|
|
||||||
include_display_names: bool,
|
include_display_names: bool,
|
||||||
) -> Result<get_keys::v3::Response> {
|
) -> Result<get_keys::v3::Response> {
|
||||||
let mut master_keys = BTreeMap::new();
|
let mut master_keys = BTreeMap::new();
|
||||||
|
@ -277,10 +227,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
let user_id: &UserId = user_id;
|
let user_id: &UserId = user_id;
|
||||||
|
|
||||||
if user_id.server_name() != services().globals.server_name() {
|
if user_id.server_name() != services().globals.server_name() {
|
||||||
get_over_federation
|
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids));
|
||||||
.entry(user_id.server_name())
|
|
||||||
.or_insert_with(Vec::new)
|
|
||||||
.push((user_id, device_ids));
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -292,9 +239,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
let metadata = services()
|
let metadata = services()
|
||||||
.users
|
.users
|
||||||
.get_device_metadata(user_id, &device_id)?
|
.get_device_metadata(user_id, &device_id)?
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
|
||||||
Error::bad_database("all_device_keys contained nonexistent device.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||||
|
@ -307,13 +252,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
for device_id in device_ids {
|
for device_id in device_ids {
|
||||||
let mut container = BTreeMap::new();
|
let mut container = BTreeMap::new();
|
||||||
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
||||||
let metadata = services()
|
let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
|
||||||
.users
|
Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."),
|
||||||
.get_device_metadata(user_id, device_id)?
|
)?;
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Tried to get keys for nonexistent device.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||||
|
@ -323,17 +264,11 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(master_key) =
|
if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? {
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.get_master_key(sender_user, user_id, &allowed_signatures)?
|
|
||||||
{
|
|
||||||
master_keys.insert(user_id.to_owned(), master_key);
|
master_keys.insert(user_id.to_owned(), master_key);
|
||||||
}
|
}
|
||||||
if let Some(self_signing_key) =
|
if let Some(self_signing_key) =
|
||||||
services()
|
services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
||||||
.users
|
|
||||||
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
|
||||||
{
|
{
|
||||||
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
||||||
}
|
}
|
||||||
|
@ -346,29 +281,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
|
|
||||||
let mut failures = BTreeMap::new();
|
let mut failures = BTreeMap::new();
|
||||||
|
|
||||||
let back_off = |id| match services()
|
let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) {
|
||||||
.globals
|
|
||||||
.bad_query_ratelimiter
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(id)
|
|
||||||
{
|
|
||||||
hash_map::Entry::Vacant(e) => {
|
hash_map::Entry::Vacant(e) => {
|
||||||
e.insert((Instant::now(), 1));
|
e.insert((Instant::now(), 1));
|
||||||
}
|
},
|
||||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||||
.into_iter()
|
.into_iter()
|
||||||
.map(|(server, vec)| async move {
|
.map(|(server, vec)| async move {
|
||||||
if let Some((time, tries)) = services()
|
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) {
|
||||||
.globals
|
|
||||||
.bad_query_ratelimiter
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.get(server)
|
|
||||||
{
|
|
||||||
// Exponential backoff
|
// Exponential backoff
|
||||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||||
|
@ -377,10 +300,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
|
|
||||||
if time.elapsed() < min_elapsed_duration {
|
if time.elapsed() < min_elapsed_duration {
|
||||||
debug!("Backing off query from {:?}", server);
|
debug!("Backing off query from {:?}", server);
|
||||||
return (
|
return (server, Err(Error::BadServerResponse("bad query, still backing off")));
|
||||||
server,
|
|
||||||
Err(Error::BadServerResponse("bad query, still backing off")),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -412,35 +332,31 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
match response {
|
match response {
|
||||||
Ok(Ok(response)) => {
|
Ok(Ok(response)) => {
|
||||||
for (user, masterkey) in response.master_keys {
|
for (user, masterkey) in response.master_keys {
|
||||||
let (master_key_id, mut master_key) =
|
let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;
|
||||||
services().users.parse_master_key(&user, &masterkey)?;
|
|
||||||
|
|
||||||
if let Some(our_master_key) = services().users.get_key(
|
if let Some(our_master_key) =
|
||||||
&master_key_id,
|
services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
|
||||||
sender_user,
|
{
|
||||||
&user,
|
let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
|
||||||
&allowed_signatures,
|
|
||||||
)? {
|
|
||||||
let (_, our_master_key) =
|
|
||||||
services().users.parse_master_key(&user, &our_master_key)?;
|
|
||||||
master_key.signatures.extend(our_master_key.signatures);
|
master_key.signatures.extend(our_master_key.signatures);
|
||||||
}
|
}
|
||||||
let json = serde_json::to_value(master_key).expect("to_value always works");
|
let json = serde_json::to_value(master_key).expect("to_value always works");
|
||||||
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
|
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
|
||||||
services().users.add_cross_signing_keys(
|
services().users.add_cross_signing_keys(
|
||||||
&user, &raw, &None, &None,
|
&user, &raw, &None, &None,
|
||||||
false, // Dont notify. A notification would trigger another key request resulting in an endless loop
|
false, /* Dont notify. A notification would trigger another key request resulting in an
|
||||||
|
* endless loop */
|
||||||
)?;
|
)?;
|
||||||
master_keys.insert(user, raw);
|
master_keys.insert(user, raw);
|
||||||
}
|
}
|
||||||
|
|
||||||
self_signing_keys.extend(response.self_signing_keys);
|
self_signing_keys.extend(response.self_signing_keys);
|
||||||
device_keys.extend(response.device_keys);
|
device_keys.extend(response.device_keys);
|
||||||
}
|
},
|
||||||
_ => {
|
_ => {
|
||||||
back_off(server.to_owned());
|
back_off(server.to_owned());
|
||||||
failures.insert(server.to_string(), json!({}));
|
failures.insert(server.to_string(), json!({}));
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -454,8 +370,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_unsigned_device_display_name(
|
fn add_unsigned_device_display_name(
|
||||||
keys: &mut Raw<ruma::encryption::DeviceKeys>,
|
keys: &mut Raw<ruma::encryption::DeviceKeys>, metadata: ruma::api::client::device::Device,
|
||||||
metadata: ruma::api::client::device::Device,
|
|
||||||
include_display_names: bool,
|
include_display_names: bool,
|
||||||
) -> serde_json::Result<()> {
|
) -> serde_json::Result<()> {
|
||||||
if let Some(display_name) = metadata.display_name {
|
if let Some(display_name) = metadata.display_name {
|
||||||
|
@ -488,19 +403,12 @@ pub(crate) async fn claim_keys_helper(
|
||||||
|
|
||||||
for (user_id, map) in one_time_keys_input {
|
for (user_id, map) in one_time_keys_input {
|
||||||
if user_id.server_name() != services().globals.server_name() {
|
if user_id.server_name() != services().globals.server_name() {
|
||||||
get_over_federation
|
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map));
|
||||||
.entry(user_id.server_name())
|
|
||||||
.or_insert_with(Vec::new)
|
|
||||||
.push((user_id, map));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut container = BTreeMap::new();
|
let mut container = BTreeMap::new();
|
||||||
for (device_id, key_algorithm) in map {
|
for (device_id, key_algorithm) in map {
|
||||||
if let Some(one_time_keys) =
|
if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? {
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.take_one_time_key(user_id, device_id, key_algorithm)?
|
|
||||||
{
|
|
||||||
let mut c = BTreeMap::new();
|
let mut c = BTreeMap::new();
|
||||||
c.insert(one_time_keys.0, one_time_keys.1);
|
c.insert(one_time_keys.0, one_time_keys.1);
|
||||||
container.insert(device_id.clone(), c);
|
container.insert(device_id.clone(), c);
|
||||||
|
@ -537,10 +445,10 @@ pub(crate) async fn claim_keys_helper(
|
||||||
match response {
|
match response {
|
||||||
Ok(keys) => {
|
Ok(keys) => {
|
||||||
one_time_keys.extend(keys.one_time_keys);
|
one_time_keys.extend(keys.one_time_keys);
|
||||||
}
|
},
|
||||||
Err(_e) => {
|
Err(_e) => {
|
||||||
failures.insert(server.to_string(), json!({}));
|
failures.insert(server.to_string(), json!({}));
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,21 +1,21 @@
|
||||||
use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration};
|
use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration};
|
||||||
|
|
||||||
|
use image::io::Reader as ImgReader;
|
||||||
|
use reqwest::Url;
|
||||||
|
use ruma::api::client::{
|
||||||
|
error::ErrorKind,
|
||||||
|
media::{
|
||||||
|
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
|
||||||
|
get_media_preview,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
use tracing::{debug, error, info, warn};
|
||||||
|
use webpage::HTML;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
service::media::{FileMeta, UrlPreviewData},
|
service::media::{FileMeta, UrlPreviewData},
|
||||||
services, utils, Error, Result, Ruma,
|
services, utils, Error, Result, Ruma,
|
||||||
};
|
};
|
||||||
use image::io::Reader as ImgReader;
|
|
||||||
|
|
||||||
use reqwest::Url;
|
|
||||||
use ruma::api::client::{
|
|
||||||
error::ErrorKind,
|
|
||||||
media::{
|
|
||||||
create_content, get_content, get_content_as_filename, get_content_thumbnail,
|
|
||||||
get_media_config, get_media_preview,
|
|
||||||
},
|
|
||||||
};
|
|
||||||
use tracing::{debug, error, info, warn};
|
|
||||||
use webpage::HTML;
|
|
||||||
|
|
||||||
/// generated MXC ID (`media-id`) length
|
/// generated MXC ID (`media-id`) length
|
||||||
const MXC_LENGTH: usize = 32;
|
const MXC_LENGTH: usize = 32;
|
||||||
|
@ -39,22 +39,13 @@ pub async fn get_media_preview_route(
|
||||||
) -> Result<get_media_preview::v3::Response> {
|
) -> Result<get_media_preview::v3::Response> {
|
||||||
let url = &body.url;
|
let url = &body.url;
|
||||||
if !url_preview_allowed(url) {
|
if !url_preview_allowed(url) {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "URL is not allowed to be previewed"));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"URL is not allowed to be previewed",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Ok(preview) = get_url_preview(url).await {
|
if let Ok(preview) = get_url_preview(url).await {
|
||||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
||||||
error!(
|
error!("Failed to convert UrlPreviewData into a serde json value: {}", e);
|
||||||
"Failed to convert UrlPreviewData into a serde json value: {}",
|
Error::BadRequest(ErrorKind::Unknown, "Unknown error occurred parsing URL preview")
|
||||||
e
|
|
||||||
);
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Unknown error occurred parsing URL preview",
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
return Ok(get_media_preview::v3::Response::from_raw_value(res));
|
return Ok(get_media_preview::v3::Response::from_raw_value(res));
|
||||||
|
@ -74,9 +65,7 @@ pub async fn get_media_preview_route(
|
||||||
///
|
///
|
||||||
/// - Some metadata will be saved in the database
|
/// - Some metadata will be saved in the database
|
||||||
/// - Media will be saved in the media/ directory
|
/// - Media will be saved in the media/ directory
|
||||||
pub async fn create_content_route(
|
pub async fn create_content_route(body: Ruma<create_content::v3::Request>) -> Result<create_content::v3::Response> {
|
||||||
body: Ruma<create_content::v3::Request>,
|
|
||||||
) -> Result<create_content::v3::Response> {
|
|
||||||
let mxc = format!(
|
let mxc = format!(
|
||||||
"mxc://{}/{}",
|
"mxc://{}/{}",
|
||||||
services().globals.server_name(),
|
services().globals.server_name(),
|
||||||
|
@ -87,10 +76,7 @@ pub async fn create_content_route(
|
||||||
.media
|
.media
|
||||||
.create(
|
.create(
|
||||||
mxc.clone(),
|
mxc.clone(),
|
||||||
body.filename
|
body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
|
||||||
.as_ref()
|
|
||||||
.map(|filename| "inline; filename=".to_owned() + filename)
|
|
||||||
.as_deref(),
|
|
||||||
body.content_type.as_deref(),
|
body.content_type.as_deref(),
|
||||||
&body.file,
|
&body.file,
|
||||||
)
|
)
|
||||||
|
@ -106,20 +92,15 @@ pub async fn create_content_route(
|
||||||
|
|
||||||
/// helper method to fetch remote media from other servers over federation
|
/// helper method to fetch remote media from other servers over federation
|
||||||
pub async fn get_remote_content(
|
pub async fn get_remote_content(
|
||||||
mxc: &str,
|
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
|
||||||
server_name: &ruma::ServerName,
|
|
||||||
media_id: String,
|
|
||||||
allow_redirect: bool,
|
|
||||||
timeout_ms: Duration,
|
|
||||||
) -> Result<get_content::v3::Response, Error> {
|
) -> Result<get_content::v3::Response, Error> {
|
||||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
// we'll lie to the client and say the blocked server's media was not found and
|
||||||
// the client has no way of telling anyways so this is a security bonus.
|
// log. the client has no way of telling anyways so this is a security bonus.
|
||||||
if services()
|
if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) {
|
||||||
.globals
|
info!(
|
||||||
.prevent_media_downloads_from()
|
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||||
.contains(&server_name.to_owned())
|
mxc
|
||||||
{
|
);
|
||||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
|
||||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -156,10 +137,9 @@ pub async fn get_remote_content(
|
||||||
///
|
///
|
||||||
/// - Only allows federation if `allow_remote` is true
|
/// - Only allows federation if `allow_remote` is true
|
||||||
/// - Only redirects if `allow_redirect` is true
|
/// - Only redirects if `allow_redirect` is true
|
||||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||||
pub async fn get_content_route(
|
/// seconds
|
||||||
body: Ruma<get_content::v3::Request>,
|
pub async fn get_content_route(body: Ruma<get_content::v3::Request>) -> Result<get_content::v3::Response> {
|
||||||
) -> Result<get_content::v3::Response> {
|
|
||||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||||
|
|
||||||
if let Some(FileMeta {
|
if let Some(FileMeta {
|
||||||
|
@ -195,14 +175,17 @@ pub async fn get_content_route(
|
||||||
///
|
///
|
||||||
/// - Only allows federation if `allow_remote` is true
|
/// - Only allows federation if `allow_remote` is true
|
||||||
/// - Only redirects if `allow_redirect` is true
|
/// - Only redirects if `allow_redirect` is true
|
||||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||||
|
/// seconds
|
||||||
pub async fn get_content_as_filename_route(
|
pub async fn get_content_as_filename_route(
|
||||||
body: Ruma<get_content_as_filename::v3::Request>,
|
body: Ruma<get_content_as_filename::v3::Request>,
|
||||||
) -> Result<get_content_as_filename::v3::Response> {
|
) -> Result<get_content_as_filename::v3::Response> {
|
||||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||||
|
|
||||||
if let Some(FileMeta {
|
if let Some(FileMeta {
|
||||||
content_type, file, ..
|
content_type,
|
||||||
|
file,
|
||||||
|
..
|
||||||
}) = services().media.get(mxc.clone()).await?
|
}) = services().media.get(mxc.clone()).await?
|
||||||
{
|
{
|
||||||
Ok(get_content_as_filename::v3::Response {
|
Ok(get_content_as_filename::v3::Response {
|
||||||
|
@ -238,24 +221,23 @@ pub async fn get_content_as_filename_route(
|
||||||
///
|
///
|
||||||
/// - Only allows federation if `allow_remote` is true
|
/// - Only allows federation if `allow_remote` is true
|
||||||
/// - Only redirects if `allow_redirect` is true
|
/// - Only redirects if `allow_redirect` is true
|
||||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||||
|
/// seconds
|
||||||
pub async fn get_content_thumbnail_route(
|
pub async fn get_content_thumbnail_route(
|
||||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||||
) -> Result<get_content_thumbnail::v3::Response> {
|
) -> Result<get_content_thumbnail::v3::Response> {
|
||||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||||
|
|
||||||
if let Some(FileMeta {
|
if let Some(FileMeta {
|
||||||
content_type, file, ..
|
content_type,
|
||||||
|
file,
|
||||||
|
..
|
||||||
}) = services()
|
}) = services()
|
||||||
.media
|
.media
|
||||||
.get_thumbnail(
|
.get_thumbnail(
|
||||||
mxc.clone(),
|
mxc.clone(),
|
||||||
body.width
|
body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||||
.try_into()
|
body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
|
||||||
body.height
|
|
||||||
.try_into()
|
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
|
||||||
)
|
)
|
||||||
.await?
|
.await?
|
||||||
{
|
{
|
||||||
|
@ -265,14 +247,13 @@ pub async fn get_content_thumbnail_route(
|
||||||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||||
})
|
})
|
||||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
// we'll lie to the client and say the blocked server's media was not found and
|
||||||
// the client has no way of telling anyways so this is a security bonus.
|
// log. the client has no way of telling anyways so this is a security bonus.
|
||||||
if services()
|
if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
|
||||||
.globals
|
info!(
|
||||||
.prevent_media_downloads_from()
|
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||||
.contains(&body.server_name.clone())
|
mxc
|
||||||
{
|
);
|
||||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
|
||||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -319,10 +300,7 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
|
||||||
utils::random_string(MXC_LENGTH)
|
utils::random_string(MXC_LENGTH)
|
||||||
);
|
);
|
||||||
|
|
||||||
services()
|
services().media.create(mxc.clone(), None, None, &image).await?;
|
||||||
.media
|
|
||||||
.create(mxc.clone(), None, None, &image)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
||||||
Err(_) => (None, None),
|
Err(_) => (None, None),
|
||||||
|
@ -348,19 +326,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
|
||||||
while let Some(chunk) = response.chunk().await? {
|
while let Some(chunk) = response.chunk().await? {
|
||||||
bytes.extend_from_slice(&chunk);
|
bytes.extend_from_slice(&chunk);
|
||||||
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
||||||
debug!("Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the response body and assuming our necessary data is in this range.", url, services().globals.url_preview_max_spider_size());
|
debug!(
|
||||||
|
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
|
||||||
|
response body and assuming our necessary data is in this range.",
|
||||||
|
url,
|
||||||
|
services().globals.url_preview_max_spider_size()
|
||||||
|
);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
let body = String::from_utf8_lossy(&bytes);
|
let body = String::from_utf8_lossy(&bytes);
|
||||||
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
|
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
|
||||||
Ok(html) => html,
|
Ok(html) => html,
|
||||||
Err(_) => {
|
Err(_) => return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Failed to parse HTML",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut data = match html.opengraph.images.first() {
|
let mut data = match html.opengraph.images.first() {
|
||||||
|
@ -399,7 +377,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
||||||
|| (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking()
|
|| (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking()
|
||||||
|| (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved()
|
|| (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved()
|
||||||
|| ip4.is_broadcast())
|
|| ip4.is_broadcast())
|
||||||
}
|
},
|
||||||
IpAddr::V6(ip6) => {
|
IpAddr::V6(ip6) => {
|
||||||
!(ip6.is_unspecified()
|
!(ip6.is_unspecified()
|
||||||
|| ip6.is_loopback()
|
|| ip6.is_loopback()
|
||||||
|
@ -426,7 +404,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
||||||
|| ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation()
|
|| ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation()
|
||||||
|| ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local()
|
|| ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local()
|
||||||
|| ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local
|
|| ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -434,38 +412,21 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||||
let client = services().globals.url_preview_client();
|
let client = services().globals.url_preview_client();
|
||||||
let response = client.head(url).send().await?;
|
let response = client.head(url).send().await?;
|
||||||
|
|
||||||
if !response
|
if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) {
|
||||||
.remote_addr()
|
|
||||||
.map_or(false, |a| url_request_allowed(&a.ip()))
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"Requesting from this address is forbidden",
|
"Requesting from this address is forbidden",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let content_type = match response
|
let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) {
|
||||||
.headers()
|
|
||||||
.get(reqwest::header::CONTENT_TYPE)
|
|
||||||
.and_then(|x| x.to_str().ok())
|
|
||||||
{
|
|
||||||
Some(ct) => ct,
|
Some(ct) => ct,
|
||||||
None => {
|
None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Unknown Content-Type",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
let data = match content_type {
|
let data = match content_type {
|
||||||
html if html.starts_with("text/html") => download_html(&client, url).await?,
|
html if html.starts_with("text/html") => download_html(&client, url).await?,
|
||||||
img if img.starts_with("image/") => download_image(&client, url).await?,
|
img if img.starts_with("image/") => download_image(&client, url).await?,
|
||||||
_ => {
|
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Unsupported Content-Type",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
services().media.set_url_preview(url, &data).await?;
|
services().media.set_url_preview(url, &data).await?;
|
||||||
|
@ -479,15 +440,8 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure that only one request is made per URL
|
// ensure that only one request is made per URL
|
||||||
let mutex_request = Arc::clone(
|
let mutex_request =
|
||||||
services()
|
Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default());
|
||||||
.media
|
|
||||||
.url_preview_mutex
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(url.to_owned())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let _request_lock = mutex_request.lock().await;
|
let _request_lock = mutex_request.lock().await;
|
||||||
|
|
||||||
match services().media.get_url_preview(url).await {
|
match services().media.get_url_preview(url).await {
|
||||||
|
@ -502,25 +456,19 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to parse URL from a str: {}", e);
|
warn!("Failed to parse URL from a str: {}", e);
|
||||||
return false;
|
return false;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if ["http", "https"]
|
if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
|
||||||
.iter()
|
|
||||||
.all(|&scheme| scheme != url.scheme().to_lowercase())
|
|
||||||
{
|
|
||||||
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
|
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
|
||||||
return false;
|
return false;
|
||||||
}
|
}
|
||||||
|
|
||||||
let host = match url.host_str() {
|
let host = match url.host_str() {
|
||||||
None => {
|
None => {
|
||||||
debug!(
|
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
|
||||||
"Ignoring URL preview for a URL that does not have a host (?): {}",
|
|
||||||
url
|
|
||||||
);
|
|
||||||
return false;
|
return false;
|
||||||
}
|
},
|
||||||
Some(h) => h.to_owned(),
|
Some(h) => h.to_owned(),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -532,41 +480,23 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
||||||
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
||||||
|| allowlist_url_contains.contains(&"*".to_owned())
|
|| allowlist_url_contains.contains(&"*".to_owned())
|
||||||
{
|
{
|
||||||
debug!(
|
debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
|
||||||
"Config key contains * which is allowing all URL previews. Allowing URL {}",
|
|
||||||
url
|
|
||||||
);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if !host.is_empty() {
|
if !host.is_empty() {
|
||||||
if allowlist_domain_explicit.contains(&host) {
|
if allowlist_domain_explicit.contains(&host) {
|
||||||
debug!(
|
debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", &host);
|
||||||
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
|
|
||||||
&host
|
|
||||||
);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowlist_domain_contains
|
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) {
|
||||||
.iter()
|
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
|
||||||
.any(|domain_s| domain_s.contains(&host.clone()))
|
|
||||||
{
|
|
||||||
debug!(
|
|
||||||
"Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
|
||||||
&host
|
|
||||||
);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowlist_url_contains
|
if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) {
|
||||||
.iter()
|
debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
|
||||||
.any(|url_s| url.to_string().contains(&url_s.to_string()))
|
|
||||||
{
|
|
||||||
debug!(
|
|
||||||
"URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)",
|
|
||||||
&host
|
|
||||||
);
|
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -584,17 +514,14 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if allowlist_domain_contains
|
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) {
|
||||||
.iter()
|
|
||||||
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
|
|
||||||
{
|
|
||||||
debug!(
|
debug!(
|
||||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||||
&root_domain
|
&root_domain
|
||||||
);
|
);
|
||||||
return true;
|
return true;
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,7 +1,8 @@
|
||||||
use crate::{
|
use std::{
|
||||||
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
collections::{BTreeMap, HashSet},
|
||||||
services, utils, Error, Result, Ruma,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
|
@ -10,47 +11,40 @@ use ruma::{
|
||||||
events::{StateEventType, TimelineEventType},
|
events::{StateEventType, TimelineEventType},
|
||||||
};
|
};
|
||||||
use serde_json::from_str;
|
use serde_json::from_str;
|
||||||
use std::{
|
|
||||||
collections::{BTreeMap, HashSet},
|
use crate::{
|
||||||
sync::Arc,
|
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
||||||
|
services, utils, Error, Result, Ruma,
|
||||||
};
|
};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
|
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
|
||||||
///
|
///
|
||||||
/// Send a message event into the room.
|
/// Send a message event into the room.
|
||||||
///
|
///
|
||||||
/// - Is a NOOP if the txn id was already used before and returns the same event id again
|
/// - Is a NOOP if the txn id was already used before and returns the same event
|
||||||
|
/// id again
|
||||||
/// - The only requirement for the content is that it has to be valid json
|
/// - The only requirement for the content is that it has to be valid json
|
||||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||||
|
/// allowed
|
||||||
pub async fn send_message_event_route(
|
pub async fn send_message_event_route(
|
||||||
body: Ruma<send_message_event::v3::Request>,
|
body: Ruma<send_message_event::v3::Request>,
|
||||||
) -> Result<send_message_event::v3::Response> {
|
) -> Result<send_message_event::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let sender_device = body.sender_device.as_deref();
|
let sender_device = body.sender_device.as_deref();
|
||||||
|
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(body.room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
// Forbid m.room.encrypted if encryption is disabled
|
// Forbid m.room.encrypted if encryption is disabled
|
||||||
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into()
|
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() && !services().globals.allow_encryption()
|
||||||
&& !services().globals.allow_encryption()
|
|
||||||
{
|
{
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Encryption has been disabled",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// certain event types require certain fields to be valid in request bodies.
|
// certain event types require certain fields to be valid in request bodies.
|
||||||
// this helps prevent attempting to handle events that we can't deserialise later so don't waste resources on it.
|
// this helps prevent attempting to handle events that we can't deserialise
|
||||||
|
// later so don't waste resources on it.
|
||||||
//
|
//
|
||||||
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
|
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
|
||||||
match body.event_type.to_string().into() {
|
match body.event_type.to_string().into() {
|
||||||
|
@ -71,7 +65,7 @@ pub async fn send_message_event_route(
|
||||||
"'msgtype' field in JSON request is invalid",
|
"'msgtype' field in JSON request is invalid",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
TimelineEventType::RoomName => {
|
TimelineEventType::RoomName => {
|
||||||
let name_field = body.body.body.get_field::<String>("name");
|
let name_field = body.body.body.get_field::<String>("name");
|
||||||
|
|
||||||
|
@ -81,7 +75,7 @@ pub async fn send_message_event_route(
|
||||||
"'name' field in JSON request is invalid",
|
"'name' field in JSON request is invalid",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
TimelineEventType::RoomTopic => {
|
TimelineEventType::RoomTopic => {
|
||||||
let topic_field = body.body.body.get_field::<String>("topic");
|
let topic_field = body.body.body.get_field::<String>("topic");
|
||||||
|
|
||||||
|
@ -91,16 +85,12 @@ pub async fn send_message_event_route(
|
||||||
"'topic' field in JSON request is invalid",
|
"'topic' field in JSON request is invalid",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
_ => {} // event may be custom/experimental or can be empty don't do anything with it
|
_ => {}, // event may be custom/experimental or can be empty don't do anything with it
|
||||||
};
|
};
|
||||||
|
|
||||||
// Check if this is a new transaction id
|
// Check if this is a new transaction id
|
||||||
if let Some(response) =
|
if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? {
|
||||||
services()
|
|
||||||
.transaction_ids
|
|
||||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
|
||||||
{
|
|
||||||
// The client might have sent a txnid of the /sendToDevice endpoint
|
// The client might have sent a txnid of the /sendToDevice endpoint
|
||||||
// This txnid has no response associated with it
|
// This txnid has no response associated with it
|
||||||
if response.is_empty() {
|
if response.is_empty() {
|
||||||
|
@ -114,7 +104,9 @@ pub async fn send_message_event_route(
|
||||||
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
|
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
|
||||||
.try_into()
|
.try_into()
|
||||||
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
|
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
|
||||||
return Ok(send_message_event::v3::Response { event_id });
|
return Ok(send_message_event::v3::Response {
|
||||||
|
event_id,
|
||||||
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut unsigned = BTreeMap::new();
|
let mut unsigned = BTreeMap::new();
|
||||||
|
@ -138,25 +130,19 @@ pub async fn send_message_event_route(
|
||||||
)
|
)
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
services().transaction_ids.add_txnid(
|
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
|
||||||
sender_user,
|
|
||||||
sender_device,
|
|
||||||
&body.txn_id,
|
|
||||||
event_id.as_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
drop(state_lock);
|
drop(state_lock);
|
||||||
|
|
||||||
Ok(send_message_event::v3::Response::new(
|
Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
|
||||||
(*event_id).to_owned(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
|
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
|
||||||
///
|
///
|
||||||
/// Allows paginating through room history.
|
/// Allows paginating through room history.
|
||||||
///
|
///
|
||||||
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was
|
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||||
|
/// where the user was
|
||||||
/// joined, depending on history_visibility)
|
/// joined, depending on history_visibility)
|
||||||
pub async fn get_message_events_route(
|
pub async fn get_message_events_route(
|
||||||
body: Ruma<get_message_events::v3::Request>,
|
body: Ruma<get_message_events::v3::Request>,
|
||||||
|
@ -172,17 +158,9 @@ pub async fn get_message_events_route(
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let to = body
|
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||||
.to
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
|
||||||
|
|
||||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(
|
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
|
||||||
sender_user,
|
|
||||||
sender_device,
|
|
||||||
&body.room_id,
|
|
||||||
from,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let limit = u64::from(body.limit).min(100) as usize;
|
let limit = u64::from(body.limit).min(100) as usize;
|
||||||
|
|
||||||
|
@ -228,21 +206,14 @@ pub async fn get_message_events_route(
|
||||||
|
|
||||||
next_token = events_after.last().map(|(count, _)| count).copied();
|
next_token = events_after.last().map(|(count, _)| count).copied();
|
||||||
|
|
||||||
let events_after: Vec<_> = events_after
|
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||||
.into_iter()
|
|
||||||
.map(|(_, pdu)| pdu.to_room_event())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
resp.start = from.stringify();
|
resp.start = from.stringify();
|
||||||
resp.end = next_token.map(|count| count.stringify());
|
resp.end = next_token.map(|count| count.stringify());
|
||||||
resp.chunk = events_after;
|
resp.chunk = events_after;
|
||||||
}
|
},
|
||||||
ruma::api::Direction::Backward => {
|
ruma::api::Direction::Backward => {
|
||||||
services()
|
services().rooms.timeline.backfill_if_required(&body.room_id, from).await?;
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.backfill_if_required(&body.room_id, from)
|
|
||||||
.await?;
|
|
||||||
let events_before: Vec<_> = services()
|
let events_before: Vec<_> = services()
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
|
@ -277,15 +248,12 @@ pub async fn get_message_events_route(
|
||||||
|
|
||||||
next_token = events_before.last().map(|(count, _)| count).copied();
|
next_token = events_before.last().map(|(count, _)| count).copied();
|
||||||
|
|
||||||
let events_before: Vec<_> = events_before
|
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||||
.into_iter()
|
|
||||||
.map(|(_, pdu)| pdu.to_room_event())
|
|
||||||
.collect();
|
|
||||||
|
|
||||||
resp.start = from.stringify();
|
resp.start = from.stringify();
|
||||||
resp.end = next_token.map(|count| count.stringify());
|
resp.end = next_token.map(|count| count.stringify());
|
||||||
resp.chunk = events_before;
|
resp.chunk = events_before;
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
resp.state = Vec::new();
|
resp.state = Vec::new();
|
||||||
|
|
|
@ -1,21 +1,18 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
use std::time::Duration;
|
||||||
|
|
||||||
use ruma::api::client::{
|
use ruma::api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
presence::{get_presence, set_presence},
|
presence::{get_presence, set_presence},
|
||||||
};
|
};
|
||||||
use std::time::Duration;
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/presence/{userId}/status`
|
/// # `PUT /_matrix/client/r0/presence/{userId}/status`
|
||||||
///
|
///
|
||||||
/// Sets the presence state of the sender user.
|
/// Sets the presence state of the sender user.
|
||||||
pub async fn set_presence_route(
|
pub async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> {
|
||||||
body: Ruma<set_presence::v3::Request>,
|
|
||||||
) -> Result<set_presence::v3::Response> {
|
|
||||||
if !services().globals.allow_local_presence() {
|
if !services().globals.allow_local_presence() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Presence is disabled on this server",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
@ -40,33 +37,19 @@ pub async fn set_presence_route(
|
||||||
/// Gets the presence state of the given user.
|
/// Gets the presence state of the given user.
|
||||||
///
|
///
|
||||||
/// - Only works if you share a room with the user
|
/// - Only works if you share a room with the user
|
||||||
pub async fn get_presence_route(
|
pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> {
|
||||||
body: Ruma<get_presence::v3::Request>,
|
|
||||||
) -> Result<get_presence::v3::Response> {
|
|
||||||
if !services().globals.allow_local_presence() {
|
if !services().globals.allow_local_presence() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Presence is disabled on this server",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let mut presence_event = None;
|
let mut presence_event = None;
|
||||||
|
|
||||||
for room_id in services()
|
for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? {
|
||||||
.rooms
|
|
||||||
.user
|
|
||||||
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
|
|
||||||
{
|
|
||||||
let room_id = room_id?;
|
let room_id = room_id?;
|
||||||
|
|
||||||
if let Some(presence) = services()
|
if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? {
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.presence
|
|
||||||
.get_presence(&room_id, sender_user)?
|
|
||||||
{
|
|
||||||
presence_event = Some(presence);
|
presence_event = Some(presence);
|
||||||
break;
|
break;
|
||||||
}
|
}
|
||||||
|
@ -77,10 +60,7 @@ pub async fn get_presence_route(
|
||||||
// TODO: Should ruma just use the presenceeventcontent type here?
|
// TODO: Should ruma just use the presenceeventcontent type here?
|
||||||
status_msg: presence.content.status_msg,
|
status_msg: presence.content.status_msg,
|
||||||
currently_active: presence.content.currently_active,
|
currently_active: presence.content.currently_active,
|
||||||
last_active_ago: presence
|
last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())),
|
||||||
.content
|
|
||||||
.last_active_ago
|
|
||||||
.map(|millis| Duration::from_millis(millis.into())),
|
|
||||||
presence: presence.content.presence,
|
presence: presence.content.presence,
|
||||||
})
|
})
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -4,9 +4,7 @@ use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::{
|
client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
profile::{
|
profile::{get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name},
|
||||||
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name,
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
federation,
|
federation,
|
||||||
},
|
},
|
||||||
|
@ -27,10 +25,7 @@ pub async fn set_displayname_route(
|
||||||
) -> Result<set_display_name::v3::Response> {
|
) -> Result<set_display_name::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().users.set_displayname(sender_user, body.displayname.clone()).await?;
|
||||||
.users
|
|
||||||
.set_displayname(sender_user, body.displayname.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Send a new membership event and presence update into all joined rooms
|
// Send a new membership event and presence update into all joined rooms
|
||||||
let all_rooms_joined: Vec<_> = services()
|
let all_rooms_joined: Vec<_> = services()
|
||||||
|
@ -48,16 +43,9 @@ pub async fn set_displayname_route(
|
||||||
services()
|
services()
|
||||||
.rooms
|
.rooms
|
||||||
.state_accessor
|
.state_accessor
|
||||||
.room_state_get(
|
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||||
&room_id,
|
|
||||||
&StateEventType::RoomMember,
|
|
||||||
sender_user.as_str(),
|
|
||||||
)?
|
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
Error::bad_database(
|
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||||
"Tried to send displayname update for user not in the \
|
|
||||||
room.",
|
|
||||||
)
|
|
||||||
})?
|
})?
|
||||||
.content
|
.content
|
||||||
.get(),
|
.get(),
|
||||||
|
@ -76,31 +64,16 @@ pub async fn set_displayname_route(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for (pdu_builder, room_id) in all_rooms_joined {
|
for (pdu_builder, room_id) in all_rooms_joined {
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
let _ = services()
|
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services().globals.allow_local_presence() {
|
if services().globals.allow_local_presence() {
|
||||||
// Presence update
|
// Presence update
|
||||||
services()
|
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.presence
|
|
||||||
.ping_presence(sender_user, PresenceState::Online)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(set_display_name::v3::Response {})
|
Ok(set_display_name::v3::Response {})
|
||||||
|
@ -132,18 +105,9 @@ pub async fn get_displayname_route(
|
||||||
services().users.create(&body.user_id, None)?;
|
services().users.create(&body.user_id, None)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||||
.users
|
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||||
.set_displayname(&body.user_id, response.displayname.clone())
|
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
return Ok(get_display_name::v3::Response {
|
return Ok(get_display_name::v3::Response {
|
||||||
displayname: response.displayname,
|
displayname: response.displayname,
|
||||||
|
@ -152,11 +116,9 @@ pub async fn get_displayname_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !services().users.exists(&body.user_id)? {
|
if !services().users.exists(&body.user_id)? {
|
||||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||||
return Err(Error::BadRequest(
|
// federation
|
||||||
ErrorKind::NotFound,
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||||
"Profile was not found.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(get_display_name::v3::Response {
|
Ok(get_display_name::v3::Response {
|
||||||
|
@ -169,20 +131,12 @@ pub async fn get_displayname_route(
|
||||||
/// Updates the avatar_url and blurhash.
|
/// Updates the avatar_url and blurhash.
|
||||||
///
|
///
|
||||||
/// - Also makes sure other users receive the update using presence EDUs
|
/// - Also makes sure other users receive the update using presence EDUs
|
||||||
pub async fn set_avatar_url_route(
|
pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> {
|
||||||
body: Ruma<set_avatar_url::v3::Request>,
|
|
||||||
) -> Result<set_avatar_url::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?;
|
||||||
.users
|
|
||||||
.set_avatar_url(sender_user, body.avatar_url.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
services()
|
services().users.set_blurhash(sender_user, body.blurhash.clone()).await?;
|
||||||
.users
|
|
||||||
.set_blurhash(sender_user, body.blurhash.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
// Send a new membership event and presence update into all joined rooms
|
// Send a new membership event and presence update into all joined rooms
|
||||||
let all_joined_rooms: Vec<_> = services()
|
let all_joined_rooms: Vec<_> = services()
|
||||||
|
@ -200,16 +154,9 @@ pub async fn set_avatar_url_route(
|
||||||
services()
|
services()
|
||||||
.rooms
|
.rooms
|
||||||
.state_accessor
|
.state_accessor
|
||||||
.room_state_get(
|
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||||
&room_id,
|
|
||||||
&StateEventType::RoomMember,
|
|
||||||
sender_user.as_str(),
|
|
||||||
)?
|
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| {
|
||||||
Error::bad_database(
|
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||||
"Tried to send displayname update for user not in the \
|
|
||||||
room.",
|
|
||||||
)
|
|
||||||
})?
|
})?
|
||||||
.content
|
.content
|
||||||
.get(),
|
.get(),
|
||||||
|
@ -228,31 +175,16 @@ pub async fn set_avatar_url_route(
|
||||||
.collect();
|
.collect();
|
||||||
|
|
||||||
for (pdu_builder, room_id) in all_joined_rooms {
|
for (pdu_builder, room_id) in all_joined_rooms {
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
let _ = services()
|
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
|
||||||
.await;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if services().globals.allow_local_presence() {
|
if services().globals.allow_local_presence() {
|
||||||
// Presence update
|
// Presence update
|
||||||
services()
|
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.presence
|
|
||||||
.ping_presence(sender_user, PresenceState::Online)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(set_avatar_url::v3::Response {})
|
Ok(set_avatar_url::v3::Response {})
|
||||||
|
@ -264,9 +196,7 @@ pub async fn set_avatar_url_route(
|
||||||
///
|
///
|
||||||
/// - If user is on another server and we do not have a local copy already
|
/// - If user is on another server and we do not have a local copy already
|
||||||
/// fetch avatar_url and blurhash over federation
|
/// fetch avatar_url and blurhash over federation
|
||||||
pub async fn get_avatar_url_route(
|
pub async fn get_avatar_url_route(body: Ruma<get_avatar_url::v3::Request>) -> Result<get_avatar_url::v3::Response> {
|
||||||
body: Ruma<get_avatar_url::v3::Request>,
|
|
||||||
) -> Result<get_avatar_url::v3::Response> {
|
|
||||||
if body.user_id.server_name() != services().globals.server_name() {
|
if body.user_id.server_name() != services().globals.server_name() {
|
||||||
// Create and update our local copy of the user
|
// Create and update our local copy of the user
|
||||||
if let Ok(response) = services()
|
if let Ok(response) = services()
|
||||||
|
@ -284,18 +214,9 @@ pub async fn get_avatar_url_route(
|
||||||
services().users.create(&body.user_id, None)?;
|
services().users.create(&body.user_id, None)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||||
.users
|
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||||
.set_displayname(&body.user_id, response.displayname.clone())
|
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
return Ok(get_avatar_url::v3::Response {
|
return Ok(get_avatar_url::v3::Response {
|
||||||
avatar_url: response.avatar_url,
|
avatar_url: response.avatar_url,
|
||||||
|
@ -305,11 +226,9 @@ pub async fn get_avatar_url_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !services().users.exists(&body.user_id)? {
|
if !services().users.exists(&body.user_id)? {
|
||||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||||
return Err(Error::BadRequest(
|
// federation
|
||||||
ErrorKind::NotFound,
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||||
"Profile was not found.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(get_avatar_url::v3::Response {
|
Ok(get_avatar_url::v3::Response {
|
||||||
|
@ -324,9 +243,7 @@ pub async fn get_avatar_url_route(
|
||||||
///
|
///
|
||||||
/// - If user is on another server and we do not have a local copy already,
|
/// - If user is on another server and we do not have a local copy already,
|
||||||
/// fetch profile over federation.
|
/// fetch profile over federation.
|
||||||
pub async fn get_profile_route(
|
pub async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> {
|
||||||
body: Ruma<get_profile::v3::Request>,
|
|
||||||
) -> Result<get_profile::v3::Response> {
|
|
||||||
if body.user_id.server_name() != services().globals.server_name() {
|
if body.user_id.server_name() != services().globals.server_name() {
|
||||||
// Create and update our local copy of the user
|
// Create and update our local copy of the user
|
||||||
if let Ok(response) = services()
|
if let Ok(response) = services()
|
||||||
|
@ -344,18 +261,9 @@ pub async fn get_profile_route(
|
||||||
services().users.create(&body.user_id, None)?;
|
services().users.create(&body.user_id, None)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||||
.users
|
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||||
.set_displayname(&body.user_id, response.displayname.clone())
|
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
|
||||||
.await?;
|
|
||||||
services()
|
|
||||||
.users
|
|
||||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
return Ok(get_profile::v3::Response {
|
return Ok(get_profile::v3::Response {
|
||||||
displayname: response.displayname,
|
displayname: response.displayname,
|
||||||
|
@ -366,11 +274,9 @@ pub async fn get_profile_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
if !services().users.exists(&body.user_id)? {
|
if !services().users.exists(&body.user_id)? {
|
||||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||||
return Err(Error::BadRequest(
|
// federation
|
||||||
ErrorKind::NotFound,
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||||
"Profile was not found.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(get_profile::v3::Response {
|
Ok(get_profile::v3::Response {
|
||||||
|
|
|
@ -1,17 +1,17 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
push::{
|
push::{
|
||||||
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled,
|
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all,
|
||||||
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions,
|
set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope,
|
||||||
set_pushrule_enabled, RuleScope,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
|
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
|
||||||
push::{InsertPushRuleError, RemovePushRuleError},
|
push::{InsertPushRuleError, RemovePushRuleError},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/pushrules`
|
/// # `GET /_matrix/client/r0/pushrules`
|
||||||
///
|
///
|
||||||
/// Retrieves the push rules event for this user.
|
/// Retrieves the push rules event for this user.
|
||||||
|
@ -22,15 +22,8 @@ pub async fn get_pushrules_all_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||||
|
@ -44,48 +37,33 @@ pub async fn get_pushrules_all_route(
|
||||||
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||||
///
|
///
|
||||||
/// Retrieves a single specified push rule for this user.
|
/// Retrieves a single specified push rule for this user.
|
||||||
pub async fn get_pushrule_route(
|
pub async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> {
|
||||||
body: Ruma<get_pushrule::v3::Request>,
|
|
||||||
) -> Result<get_pushrule::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||||
.content;
|
.content;
|
||||||
|
|
||||||
let rule = account_data
|
let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into);
|
||||||
.global
|
|
||||||
.get(body.kind.clone(), &body.rule_id)
|
|
||||||
.map(Into::into);
|
|
||||||
|
|
||||||
if let Some(rule) = rule {
|
if let Some(rule) = rule {
|
||||||
Ok(get_pushrule::v3::Response { rule })
|
Ok(get_pushrule::v3::Response {
|
||||||
|
rule,
|
||||||
|
})
|
||||||
} else {
|
} else {
|
||||||
Err(Error::BadRequest(
|
Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Push rule not found.",
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||||
///
|
///
|
||||||
/// Creates a single specified push rule for this user.
|
/// Creates a single specified push rule for this user.
|
||||||
pub async fn set_pushrule_route(
|
pub async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> {
|
||||||
body: Ruma<set_pushrule::v3::Request>,
|
|
||||||
) -> Result<set_pushrule::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let body = body.body;
|
let body = body.body;
|
||||||
|
|
||||||
|
@ -98,41 +76,30 @@ pub async fn set_pushrule_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||||
|
|
||||||
if let Err(error) = account_data.content.global.insert(
|
if let Err(error) =
|
||||||
body.rule.clone(),
|
account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref())
|
||||||
body.after.as_deref(),
|
{
|
||||||
body.before.as_deref(),
|
|
||||||
) {
|
|
||||||
let err = match error {
|
let err = match error {
|
||||||
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Rule IDs starting with a dot are reserved for server-default rules.",
|
"Rule IDs starting with a dot are reserved for server-default rules.",
|
||||||
),
|
),
|
||||||
InsertPushRuleError::InvalidRuleId => Error::BadRequest(
|
InsertPushRuleError::InvalidRuleId => {
|
||||||
ErrorKind::InvalidParam,
|
Error::BadRequest(ErrorKind::InvalidParam, "Rule ID containing invalid characters.")
|
||||||
"Rule ID containing invalid characters.",
|
},
|
||||||
),
|
|
||||||
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Can't place a push rule relatively to a server-default rule.",
|
"Can't place a push rule relatively to a server-default rule.",
|
||||||
),
|
),
|
||||||
InsertPushRuleError::UnknownRuleId => Error::BadRequest(
|
InsertPushRuleError::UnknownRuleId => {
|
||||||
ErrorKind::NotFound,
|
Error::BadRequest(ErrorKind::NotFound, "The before or after rule could not be found.")
|
||||||
"The before or after rule could not be found.",
|
},
|
||||||
),
|
|
||||||
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"The before rule has a higher priority than the after rule.",
|
"The before rule has a higher priority than the after rule.",
|
||||||
|
@ -170,15 +137,8 @@ pub async fn get_pushrule_actions_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||||
|
@ -188,12 +148,11 @@ pub async fn get_pushrule_actions_route(
|
||||||
let actions = global
|
let actions = global
|
||||||
.get(body.kind.clone(), &body.rule_id)
|
.get(body.kind.clone(), &body.rule_id)
|
||||||
.map(|rule| rule.actions().to_owned())
|
.map(|rule| rule.actions().to_owned())
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Push rule not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(get_pushrule_actions::v3::Response { actions })
|
Ok(get_pushrule_actions::v3::Response {
|
||||||
|
actions,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
|
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
|
||||||
|
@ -213,29 +172,14 @@ pub async fn set_pushrule_actions_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||||
|
|
||||||
if account_data
|
if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() {
|
||||||
.content
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||||
.global
|
|
||||||
.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone())
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Push rule not found.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
services().account_data.update(
|
services().account_data.update(
|
||||||
|
@ -265,15 +209,8 @@ pub async fn get_pushrule_enabled_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||||
|
@ -282,12 +219,11 @@ pub async fn get_pushrule_enabled_route(
|
||||||
let enabled = global
|
let enabled = global
|
||||||
.get(body.kind.clone(), &body.rule_id)
|
.get(body.kind.clone(), &body.rule_id)
|
||||||
.map(ruma::push::AnyPushRuleRef::enabled)
|
.map(ruma::push::AnyPushRuleRef::enabled)
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Push rule not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
Ok(get_pushrule_enabled::v3::Response { enabled })
|
Ok(get_pushrule_enabled::v3::Response {
|
||||||
|
enabled,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
|
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
|
||||||
|
@ -307,29 +243,14 @@ pub async fn set_pushrule_enabled_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||||
|
|
||||||
if account_data
|
if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() {
|
||||||
.content
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||||
.global
|
|
||||||
.set_enabled(body.kind.clone(), &body.rule_id, body.enabled)
|
|
||||||
.is_err()
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Push rule not found.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
services().account_data.update(
|
services().account_data.update(
|
||||||
|
@ -345,9 +266,7 @@ pub async fn set_pushrule_enabled_route(
|
||||||
/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||||
///
|
///
|
||||||
/// Deletes a single specified push rule for this user.
|
/// Deletes a single specified push rule for this user.
|
||||||
pub async fn delete_pushrule_route(
|
pub async fn delete_pushrule_route(body: Ruma<delete_pushrule::v3::Request>) -> Result<delete_pushrule::v3::Response> {
|
||||||
body: Ruma<delete_pushrule::v3::Request>,
|
|
||||||
) -> Result<delete_pushrule::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if body.scope != RuleScope::Global {
|
if body.scope != RuleScope::Global {
|
||||||
|
@ -359,32 +278,18 @@ pub async fn delete_pushrule_route(
|
||||||
|
|
||||||
let event = services()
|
let event = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||||
None,
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||||
sender_user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)?
|
|
||||||
.ok_or(Error::BadRequest(
|
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PushRules event not found.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||||
|
|
||||||
if let Err(error) = account_data
|
if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) {
|
||||||
.content
|
|
||||||
.global
|
|
||||||
.remove(body.kind.clone(), &body.rule_id)
|
|
||||||
{
|
|
||||||
let err = match error {
|
let err = match error {
|
||||||
RemovePushRuleError::ServerDefault => Error::BadRequest(
|
RemovePushRuleError::ServerDefault => {
|
||||||
ErrorKind::InvalidParam,
|
Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.")
|
||||||
"Cannot delete a server-default pushrule.",
|
},
|
||||||
),
|
RemovePushRuleError::NotFound => Error::BadRequest(ErrorKind::NotFound, "Push rule not found."),
|
||||||
RemovePushRuleError::NotFound => {
|
|
||||||
Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")
|
|
||||||
}
|
|
||||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -404,9 +309,7 @@ pub async fn delete_pushrule_route(
|
||||||
/// # `GET /_matrix/client/r0/pushers`
|
/// # `GET /_matrix/client/r0/pushers`
|
||||||
///
|
///
|
||||||
/// Gets all currently active pushers for the sender user.
|
/// Gets all currently active pushers for the sender user.
|
||||||
pub async fn get_pushers_route(
|
pub async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> {
|
||||||
body: Ruma<get_pushers::v3::Request>,
|
|
||||||
) -> Result<get_pushers::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
Ok(get_pushers::v3::Response {
|
Ok(get_pushers::v3::Response {
|
||||||
|
@ -419,14 +322,10 @@ pub async fn get_pushers_route(
|
||||||
/// Adds a pusher for the sender user.
|
/// Adds a pusher for the sender user.
|
||||||
///
|
///
|
||||||
/// - TODO: Handle `append`
|
/// - TODO: Handle `append`
|
||||||
pub async fn set_pushers_route(
|
pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> {
|
||||||
body: Ruma<set_pusher::v3::Request>,
|
|
||||||
) -> Result<set_pusher::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
services()
|
services().pusher.set_pusher(sender_user, body.action.clone())?;
|
||||||
.pusher
|
|
||||||
.set_pusher(sender_user, body.action.clone())?;
|
|
||||||
|
|
||||||
Ok(set_pusher::v3::Response::default())
|
Ok(set_pusher::v3::Response::default())
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
||||||
events::{
|
events::{
|
||||||
|
@ -7,17 +8,17 @@ use ruma::{
|
||||||
},
|
},
|
||||||
MilliSecondsSinceUnixEpoch,
|
MilliSecondsSinceUnixEpoch,
|
||||||
};
|
};
|
||||||
use std::collections::BTreeMap;
|
|
||||||
|
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
|
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
|
||||||
///
|
///
|
||||||
/// Sets different types of read markers.
|
/// Sets different types of read markers.
|
||||||
///
|
///
|
||||||
/// - Updates fully-read account data event to `fully_read`
|
/// - Updates fully-read account data event to `fully_read`
|
||||||
/// - If `read_receipt` is set: Update private marker and public read receipt EDU
|
/// - If `read_receipt` is set: Update private marker and public read receipt
|
||||||
pub async fn set_read_marker_route(
|
/// EDU
|
||||||
body: Ruma<set_read_marker::v3::Request>,
|
pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) -> Result<set_read_marker::v3::Response> {
|
||||||
) -> Result<set_read_marker::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if let Some(fully_read) = &body.fully_read {
|
if let Some(fully_read) = &body.fully_read {
|
||||||
|
@ -35,10 +36,7 @@ pub async fn set_read_marker_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
|
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
|
||||||
services()
|
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||||
.rooms
|
|
||||||
.user
|
|
||||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(event) = &body.private_read_receipt {
|
if let Some(event) = &body.private_read_receipt {
|
||||||
|
@ -46,24 +44,17 @@ pub async fn set_read_marker_route(
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.get_pdu_count(event)?
|
.get_pdu_count(event)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Event does not exist.",
|
|
||||||
))?;
|
|
||||||
let count = match count {
|
let count = match count {
|
||||||
PduCount::Backfilled(_) => {
|
PduCount::Backfilled(_) => {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Read receipt is in backfilled timeline",
|
"Read receipt is in backfilled timeline",
|
||||||
))
|
))
|
||||||
}
|
},
|
||||||
PduCount::Normal(c) => c,
|
PduCount::Normal(c) => c,
|
||||||
};
|
};
|
||||||
services()
|
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.read_receipt
|
|
||||||
.private_read_set(&body.room_id, sender_user, count)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(event) = &body.read_receipt {
|
if let Some(event) = &body.read_receipt {
|
||||||
|
@ -98,19 +89,14 @@ pub async fn set_read_marker_route(
|
||||||
/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}`
|
/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}`
|
||||||
///
|
///
|
||||||
/// Sets private read marker and public read receipt EDU.
|
/// Sets private read marker and public read receipt EDU.
|
||||||
pub async fn create_receipt_route(
|
pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Result<create_receipt::v3::Response> {
|
||||||
body: Ruma<create_receipt::v3::Request>,
|
|
||||||
) -> Result<create_receipt::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if matches!(
|
if matches!(
|
||||||
&body.receipt_type,
|
&body.receipt_type,
|
||||||
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
|
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
|
||||||
) {
|
) {
|
||||||
services()
|
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||||
.rooms
|
|
||||||
.user
|
|
||||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
match body.receipt_type {
|
match body.receipt_type {
|
||||||
|
@ -126,7 +112,7 @@ pub async fn create_receipt_route(
|
||||||
RoomAccountDataEventType::FullyRead,
|
RoomAccountDataEventType::FullyRead,
|
||||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||||
)?;
|
)?;
|
||||||
}
|
},
|
||||||
create_receipt::v3::ReceiptType::Read => {
|
create_receipt::v3::ReceiptType::Read => {
|
||||||
let mut user_receipts = BTreeMap::new();
|
let mut user_receipts = BTreeMap::new();
|
||||||
user_receipts.insert(
|
user_receipts.insert(
|
||||||
|
@ -150,31 +136,24 @@ pub async fn create_receipt_route(
|
||||||
room_id: body.room_id.clone(),
|
room_id: body.room_id.clone(),
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
}
|
},
|
||||||
create_receipt::v3::ReceiptType::ReadPrivate => {
|
create_receipt::v3::ReceiptType::ReadPrivate => {
|
||||||
let count = services()
|
let count = services()
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.get_pdu_count(&body.event_id)?
|
.get_pdu_count(&body.event_id)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Event does not exist.",
|
|
||||||
))?;
|
|
||||||
let count = match count {
|
let count = match count {
|
||||||
PduCount::Backfilled(_) => {
|
PduCount::Backfilled(_) => {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Read receipt is in backfilled timeline",
|
"Read receipt is in backfilled timeline",
|
||||||
))
|
))
|
||||||
}
|
},
|
||||||
PduCount::Normal(c) => c,
|
PduCount::Normal(c) => c,
|
||||||
};
|
};
|
||||||
services().rooms.edus.read_receipt.private_read_set(
|
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||||
&body.room_id,
|
},
|
||||||
sender_user,
|
|
||||||
count,
|
|
||||||
)?;
|
|
||||||
}
|
|
||||||
_ => return Err(Error::bad_database("Unsupported receipt type")),
|
_ => return Err(Error::bad_database("Unsupported receipt type")),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,33 +1,24 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::redact::redact_event,
|
api::client::redact::redact_event,
|
||||||
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
|
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
|
||||||
};
|
};
|
||||||
|
|
||||||
use serde_json::value::to_raw_value;
|
use serde_json::value::to_raw_value;
|
||||||
|
|
||||||
|
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
|
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
|
||||||
///
|
///
|
||||||
/// Tries to send a redaction event into the room.
|
/// Tries to send a redaction event into the room.
|
||||||
///
|
///
|
||||||
/// - TODO: Handle txn id
|
/// - TODO: Handle txn id
|
||||||
pub async fn redact_event_route(
|
pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> {
|
||||||
body: Ruma<redact_event::v3::Request>,
|
|
||||||
) -> Result<redact_event::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let body = body.body;
|
let body = body.body;
|
||||||
|
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(body.room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
let event_id = services()
|
let event_id = services()
|
||||||
|
@ -54,5 +45,7 @@ pub async fn redact_event_route(
|
||||||
drop(state_lock);
|
drop(state_lock);
|
||||||
|
|
||||||
let event_id = (*event_id).to_owned();
|
let event_id = (*event_id).to_owned();
|
||||||
Ok(redact_event::v3::Response { event_id })
|
Ok(redact_event::v3::Response {
|
||||||
|
event_id,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use ruma::api::client::relations::{
|
use ruma::api::client::relations::{
|
||||||
get_relating_events, get_relating_events_with_rel_type,
|
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
|
||||||
get_relating_events_with_rel_type_and_event_type,
|
|
||||||
};
|
};
|
||||||
|
|
||||||
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
|
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
|
||||||
|
@ -20,22 +19,12 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let to = body
|
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||||
.to
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
|
||||||
|
|
||||||
// Use limit or else 10, with maximum 100
|
// Use limit or else 10, with maximum 100
|
||||||
let limit = body
|
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||||
.limit
|
|
||||||
.and_then(|u| u32::try_from(u).ok())
|
|
||||||
.map_or(10_usize, |u| u as usize)
|
|
||||||
.min(100);
|
|
||||||
|
|
||||||
let res = services()
|
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||||
.rooms
|
|
||||||
.pdu_metadata
|
|
||||||
.paginate_relations_with_filter(
|
|
||||||
sender_user,
|
sender_user,
|
||||||
&body.room_id,
|
&body.room_id,
|
||||||
&body.event_id,
|
&body.event_id,
|
||||||
|
@ -46,13 +35,11 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
|
||||||
limit,
|
limit,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(
|
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
|
||||||
get_relating_events_with_rel_type_and_event_type::v1::Response {
|
|
||||||
chunk: res.chunk,
|
chunk: res.chunk,
|
||||||
next_batch: res.next_batch,
|
next_batch: res.next_batch,
|
||||||
prev_batch: res.prev_batch,
|
prev_batch: res.prev_batch,
|
||||||
},
|
})
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
|
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
|
||||||
|
@ -70,22 +57,12 @@ pub async fn get_relating_events_with_rel_type_route(
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let to = body
|
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||||
.to
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
|
||||||
|
|
||||||
// Use limit or else 10, with maximum 100
|
// Use limit or else 10, with maximum 100
|
||||||
let limit = body
|
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||||
.limit
|
|
||||||
.and_then(|u| u32::try_from(u).ok())
|
|
||||||
.map_or(10_usize, |u| u as usize)
|
|
||||||
.min(100);
|
|
||||||
|
|
||||||
let res = services()
|
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||||
.rooms
|
|
||||||
.pdu_metadata
|
|
||||||
.paginate_relations_with_filter(
|
|
||||||
sender_user,
|
sender_user,
|
||||||
&body.room_id,
|
&body.room_id,
|
||||||
&body.event_id,
|
&body.event_id,
|
||||||
|
@ -118,22 +95,12 @@ pub async fn get_relating_events_route(
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let to = body
|
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||||
.to
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
|
||||||
|
|
||||||
// Use limit or else 10, with maximum 100
|
// Use limit or else 10, with maximum 100
|
||||||
let limit = body
|
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||||
.limit
|
|
||||||
.and_then(|u| u32::try_from(u).ok())
|
|
||||||
.map_or(10_usize, |u| u as usize)
|
|
||||||
.min(100);
|
|
||||||
|
|
||||||
services()
|
services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||||
.rooms
|
|
||||||
.pdu_metadata
|
|
||||||
.paginate_relations_with_filter(
|
|
||||||
sender_user,
|
sender_user,
|
||||||
&body.room_id,
|
&body.room_id,
|
||||||
&body.event_id,
|
&body.event_id,
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
|
|
||||||
use rand::Rng;
|
use rand::Rng;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{error::ErrorKind, room::report_content},
|
api::client::{error::ErrorKind, room::report_content},
|
||||||
|
@ -10,13 +9,12 @@ use ruma::{
|
||||||
use tokio::time::sleep;
|
use tokio::time::sleep;
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
|
use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
|
/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
|
||||||
///
|
///
|
||||||
/// Reports an inappropriate event to homeserver admins
|
/// Reports an inappropriate event to homeserver admins
|
||||||
///
|
pub async fn report_event_route(body: Ruma<report_content::v3::Request>) -> Result<report_content::v3::Response> {
|
||||||
pub async fn report_event_route(
|
|
||||||
body: Ruma<report_content::v3::Request>,
|
|
||||||
) -> Result<report_content::v3::Response> {
|
|
||||||
// user authentication
|
// user authentication
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
|
@ -30,7 +28,7 @@ pub async fn report_event_route(
|
||||||
ErrorKind::NotFound,
|
ErrorKind::NotFound,
|
||||||
"Event ID is not known to us or Event ID is invalid",
|
"Event ID is not known to us or Event ID is invalid",
|
||||||
))
|
))
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// check if the room ID from the URI matches the PDU's room ID
|
// check if the room ID from the URI matches the PDU's room ID
|
||||||
|
@ -71,17 +69,12 @@ pub async fn report_event_route(
|
||||||
));
|
));
|
||||||
};
|
};
|
||||||
|
|
||||||
// send admin room message that we received the report with an @room ping for urgency
|
// send admin room message that we received the report with an @room ping for
|
||||||
services()
|
// urgency
|
||||||
.admin
|
services().admin.send_message(message::RoomMessageEventContent::text_html(
|
||||||
.send_message(message::RoomMessageEventContent::text_html(
|
|
||||||
format!(
|
format!(
|
||||||
"@room Report received from: {}\n\n\
|
"@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \
|
||||||
Event ID: {}\n\
|
Reason: {}",
|
||||||
Room ID: {}\n\
|
|
||||||
Sent By: {}\n\n\
|
|
||||||
Report Score: {}\n\
|
|
||||||
Report Reason: {}",
|
|
||||||
sender_user.to_owned(),
|
sender_user.to_owned(),
|
||||||
pdu.event_id,
|
pdu.event_id,
|
||||||
pdu.room_id,
|
pdu.room_id,
|
||||||
|
@ -105,8 +98,9 @@ pub async fn report_event_route(
|
||||||
),
|
),
|
||||||
));
|
));
|
||||||
|
|
||||||
// even though this is kinda security by obscurity, let's still make a small random delay sending a successful response
|
// even though this is kinda security by obscurity, let's still make a small
|
||||||
// per spec suggestion regarding enumerating for potential events existing in our server.
|
// random delay sending a successful response per spec suggestion regarding
|
||||||
|
// enumerating for potential events existing in our server.
|
||||||
let time_to_wait = rand::thread_rng().gen_range(8..21);
|
let time_to_wait = rand::thread_rng().gen_range(8..21);
|
||||||
debug!(
|
debug!(
|
||||||
"Got successful /report request, waiting {} seconds before sending successful response.",
|
"Got successful /report request, waiting {} seconds before sending successful response.",
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use crate::{
|
use std::{cmp::max, collections::BTreeMap, sync::Arc};
|
||||||
api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma,
|
|
||||||
};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
|
@ -23,13 +22,13 @@ use ruma::{
|
||||||
},
|
},
|
||||||
int,
|
int,
|
||||||
serde::JsonObject,
|
serde::JsonObject,
|
||||||
CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId,
|
CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId,
|
||||||
RoomVersionId,
|
|
||||||
};
|
};
|
||||||
use serde_json::{json, value::to_raw_value};
|
use serde_json::{json, value::to_raw_value};
|
||||||
use std::{cmp::max, collections::BTreeMap, sync::Arc};
|
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::{api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/v3/createRoom`
|
/// # `POST /_matrix/client/v3/createRoom`
|
||||||
///
|
///
|
||||||
/// Creates a new room.
|
/// Creates a new room.
|
||||||
|
@ -46,27 +45,19 @@ use tracing::{debug, error, info, warn};
|
||||||
/// - Send events listed in initial state
|
/// - Send events listed in initial state
|
||||||
/// - Send events implied by `name` and `topic`
|
/// - Send events implied by `name` and `topic`
|
||||||
/// - Send invite events
|
/// - Send invite events
|
||||||
pub async fn create_room_route(
|
pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<create_room::v3::Response> {
|
||||||
body: Ruma<create_room::v3::Request>,
|
|
||||||
) -> Result<create_room::v3::Response> {
|
|
||||||
use create_room::v3::RoomPreset;
|
use create_room::v3::RoomPreset;
|
||||||
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services().globals.allow_room_creation()
|
if !services().globals.allow_room_creation() && !&body.from_appservice && !services().users.is_admin(sender_user)? {
|
||||||
&& !&body.from_appservice
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Room creation has been disabled."));
|
||||||
&& !services().users.is_admin(sender_user)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Room creation has been disabled.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let room_id: OwnedRoomId;
|
let room_id: OwnedRoomId;
|
||||||
|
|
||||||
// checks if the user specified an explicit (custom) room_id to be created with in request body.
|
// checks if the user specified an explicit (custom) room_id to be created with
|
||||||
// falls back to normal generated room ID if not specified.
|
// in request body. falls back to normal generated room ID if not specified.
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = &body.json_body {
|
if let Some(CanonicalJsonValue::Object(json_body)) = &body.json_body {
|
||||||
match json_body.get("room_id") {
|
match json_body.get("room_id") {
|
||||||
Some(custom_room_id) => {
|
Some(custom_room_id) => {
|
||||||
|
@ -76,7 +67,8 @@ pub async fn create_room_route(
|
||||||
if custom_room_id_s.contains(':') {
|
if custom_room_id_s.contains(':') {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Custom room ID contained `:` which is not allowed. Please note that this expects a localpart, not the full room ID.",
|
"Custom room ID contained `:` which is not allowed. Please note that this expects a \
|
||||||
|
localpart, not the full room ID.",
|
||||||
));
|
));
|
||||||
} else if custom_room_id_s.contains(char::is_whitespace) {
|
} else if custom_room_id_s.contains(char::is_whitespace) {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
|
@ -84,41 +76,24 @@ pub async fn create_room_route(
|
||||||
"Custom room ID contained spaces which is not valid.",
|
"Custom room ID contained spaces which is not valid.",
|
||||||
));
|
));
|
||||||
} else if custom_room_id_s.len() > 255 {
|
} else if custom_room_id_s.len() > 255 {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID is too long."));
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Custom room ID is too long.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// apply forbidden room alias checks to custom room IDs too
|
// apply forbidden room alias checks to custom room IDs too
|
||||||
if services()
|
if services().globals.forbidden_room_names().is_match(&custom_room_id_s) {
|
||||||
.globals
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Custom room ID is forbidden."));
|
||||||
.forbidden_room_names()
|
|
||||||
.is_match(&custom_room_id_s)
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Custom room ID is forbidden.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let full_room_id = "!".to_owned()
|
let full_room_id = "!".to_owned()
|
||||||
+ &custom_room_id_s.replace('"', "")
|
+ &custom_room_id_s.replace('"', "")
|
||||||
+ ":"
|
+ ":" + services().globals.server_name().as_ref();
|
||||||
+ services().globals.server_name().as_ref();
|
|
||||||
debug!("Full room ID: {}", full_room_id);
|
debug!("Full room ID: {}", full_room_id);
|
||||||
|
|
||||||
room_id = RoomId::parse(full_room_id).map_err(|e| {
|
room_id = RoomId::parse(full_room_id).map_err(|e| {
|
||||||
info!(
|
info!("User attempted to create room with custom room ID but failed parsing: {}", e);
|
||||||
"User attempted to create room with custom room ID but failed parsing: {}",
|
Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID could not be parsed")
|
||||||
e
|
|
||||||
);
|
|
||||||
Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Custom room ID could not be parsed",
|
|
||||||
)
|
|
||||||
})?;
|
})?;
|
||||||
}
|
},
|
||||||
None => room_id = RoomId::new(services().globals.server_name()),
|
None => room_id = RoomId::new(services().globals.server_name()),
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
|
@ -135,27 +110,17 @@ pub async fn create_room_route(
|
||||||
|
|
||||||
services().rooms.short.get_or_create_shortroomid(&room_id)?;
|
services().rooms.short.get_or_create_shortroomid(&room_id)?;
|
||||||
|
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
let alias: Option<OwnedRoomAliasId> =
|
let alias: Option<OwnedRoomAliasId> = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| {
|
||||||
body.room_alias_name
|
|
||||||
.as_ref()
|
|
||||||
.map_or(Ok(None), |localpart| {
|
|
||||||
|
|
||||||
// Basic checks on the room alias validity
|
// Basic checks on the room alias validity
|
||||||
if localpart.contains(':') {
|
if localpart.contains(':') {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the full room alias.",
|
"Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the \
|
||||||
|
full room alias.",
|
||||||
));
|
));
|
||||||
} else if localpart.contains(char::is_whitespace) {
|
} else if localpart.contains(char::is_whitespace) {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
|
@ -164,9 +129,9 @@ pub async fn create_room_route(
|
||||||
));
|
));
|
||||||
} else if localpart.len() > 255 {
|
} else if localpart.len() > 255 {
|
||||||
// there is nothing spec-wise saying to check the limit of this,
|
// there is nothing spec-wise saying to check the limit of this,
|
||||||
// however absurdly long room aliases are guaranteed to be unreadable or done maliciously.
|
// however absurdly long room aliases are guaranteed to be unreadable or done
|
||||||
// there is no reason a room alias should even exceed 100 characters as is.
|
// maliciously. there is no reason a room alias should even exceed 100
|
||||||
// generally in spec, 255 is matrix's fav number
|
// characters as is. generally in spec, 255 is matrix's fav number
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
"Room alias is excessively long, clients may not be able to handle this. Please shorten it.",
|
"Room alias is excessively long, clients may not be able to handle this. Please shorten it.",
|
||||||
|
@ -179,37 +144,18 @@ pub async fn create_room_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
// check if room alias is forbidden
|
// check if room alias is forbidden
|
||||||
if services()
|
if services().globals.forbidden_room_names().is_match(localpart) {
|
||||||
.globals
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden."));
|
||||||
.forbidden_room_names()
|
|
||||||
.is_match(localpart)
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Unknown,
|
|
||||||
"Room alias name is forbidden.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let alias = RoomAliasId::parse(format!(
|
let alias =
|
||||||
"#{}:{}",
|
RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())).map_err(|e| {
|
||||||
localpart,
|
|
||||||
services().globals.server_name()
|
|
||||||
))
|
|
||||||
.map_err(|e| {
|
|
||||||
warn!("Failed to parse room alias for room ID {}: {e}", room_id);
|
warn!("Failed to parse room alias for room ID {}: {e}", room_id);
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.")
|
Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if services()
|
if services().rooms.alias.resolve_local_alias(&alias)?.is_some() {
|
||||||
.rooms
|
Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists."))
|
||||||
.alias
|
|
||||||
.resolve_local_alias(&alias)?
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
Err(Error::BadRequest(
|
|
||||||
ErrorKind::RoomInUse,
|
|
||||||
"Room alias already exists.",
|
|
||||||
))
|
|
||||||
} else {
|
} else {
|
||||||
Ok(Some(alias))
|
Ok(Some(alias))
|
||||||
}
|
}
|
||||||
|
@ -217,11 +163,7 @@ pub async fn create_room_route(
|
||||||
|
|
||||||
let room_version = match body.room_version.clone() {
|
let room_version = match body.room_version.clone() {
|
||||||
Some(room_version) => {
|
Some(room_version) => {
|
||||||
if services()
|
if services().globals.supported_room_versions().contains(&room_version) {
|
||||||
.globals
|
|
||||||
.supported_room_versions()
|
|
||||||
.contains(&room_version)
|
|
||||||
{
|
|
||||||
room_version
|
room_version
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
|
@ -229,15 +171,13 @@ pub async fn create_room_route(
|
||||||
"This server does not support that room version.",
|
"This server does not support that room version.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
None => services().globals.default_room_version(),
|
None => services().globals.default_room_version(),
|
||||||
};
|
};
|
||||||
|
|
||||||
let content = match &body.creation_content {
|
let content = match &body.creation_content {
|
||||||
Some(content) => {
|
Some(content) => {
|
||||||
let mut content = content
|
let mut content = content.deserialize_as::<CanonicalJsonObject>().map_err(|e| {
|
||||||
.deserialize_as::<CanonicalJsonObject>()
|
|
||||||
.map_err(|e| {
|
|
||||||
error!("Failed to deserialise content as canonical JSON: {}", e);
|
error!("Failed to deserialise content as canonical JSON: {}", e);
|
||||||
Error::bad_database("Failed to deserialise content as canonical JSON.")
|
Error::bad_database("Failed to deserialise content as canonical JSON.")
|
||||||
})?;
|
})?;
|
||||||
|
@ -259,25 +199,25 @@ pub async fn create_room_route(
|
||||||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
}
|
},
|
||||||
RoomVersionId::V11 => {} // V11 removed the "creator" key
|
RoomVersionId::V11 => {}, // V11 removed the "creator" key
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unexpected or unsupported room version {}", room_version);
|
warn!("Unexpected or unsupported room version {}", room_version);
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::BadJson,
|
ErrorKind::BadJson,
|
||||||
"Unexpected or unsupported room version found",
|
"Unexpected or unsupported room version found",
|
||||||
));
|
));
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
content.insert(
|
content.insert(
|
||||||
"room_version".into(),
|
"room_version".into(),
|
||||||
json!(room_version.as_str()).try_into().map_err(|_| {
|
json!(room_version.as_str())
|
||||||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
.try_into()
|
||||||
})?,
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?,
|
||||||
);
|
);
|
||||||
content
|
content
|
||||||
}
|
},
|
||||||
None => {
|
None => {
|
||||||
// TODO: Add correct value for v11
|
// TODO: Add correct value for v11
|
||||||
let content = match room_version {
|
let content = match room_version {
|
||||||
|
@ -298,7 +238,7 @@ pub async fn create_room_route(
|
||||||
ErrorKind::BadJson,
|
ErrorKind::BadJson,
|
||||||
"Unexpected or unsupported room version found",
|
"Unexpected or unsupported room version found",
|
||||||
));
|
));
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
let mut content = serde_json::from_str::<CanonicalJsonObject>(
|
let mut content = serde_json::from_str::<CanonicalJsonObject>(
|
||||||
to_raw_value(&content)
|
to_raw_value(&content)
|
||||||
|
@ -308,26 +248,20 @@ pub async fn create_room_route(
|
||||||
.unwrap();
|
.unwrap();
|
||||||
content.insert(
|
content.insert(
|
||||||
"room_version".into(),
|
"room_version".into(),
|
||||||
json!(room_version.as_str()).try_into().map_err(|_| {
|
json!(room_version.as_str())
|
||||||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
.try_into()
|
||||||
})?,
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?,
|
||||||
);
|
);
|
||||||
content
|
content
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
// Validate creation content
|
// Validate creation content
|
||||||
let de_result = serde_json::from_str::<CanonicalJsonObject>(
|
let de_result =
|
||||||
to_raw_value(&content)
|
serde_json::from_str::<CanonicalJsonObject>(to_raw_value(&content).expect("Invalid creation content").get());
|
||||||
.expect("Invalid creation content")
|
|
||||||
.get(),
|
|
||||||
);
|
|
||||||
|
|
||||||
if de_result.is_err() {
|
if de_result.is_err() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"));
|
||||||
ErrorKind::BadJson,
|
|
||||||
"Invalid creation content",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 1. The room create event
|
// 1. The room create event
|
||||||
|
@ -401,9 +335,7 @@ pub async fn create_room_route(
|
||||||
|
|
||||||
if let Some(power_level_content_override) = &body.power_level_content_override {
|
if let Some(power_level_content_override) = &body.power_level_content_override {
|
||||||
let json: JsonObject = serde_json::from_str(power_level_content_override.json().get())
|
let json: JsonObject = serde_json::from_str(power_level_content_override.json().get())
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override."))?;
|
||||||
Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
for (key, value) in json {
|
for (key, value) in json {
|
||||||
power_levels_content[key] = value;
|
power_levels_content[key] = value;
|
||||||
|
@ -416,8 +348,7 @@ pub async fn create_room_route(
|
||||||
.build_and_append_pdu(
|
.build_and_append_pdu(
|
||||||
PduBuilder {
|
PduBuilder {
|
||||||
event_type: TimelineEventType::RoomPowerLevels,
|
event_type: TimelineEventType::RoomPowerLevels,
|
||||||
content: to_raw_value(&power_levels_content)
|
content: to_raw_value(&power_levels_content).expect("to_raw_value always works on serde_json::Value"),
|
||||||
.expect("to_raw_value always works on serde_json::Value"),
|
|
||||||
unsigned: None,
|
unsigned: None,
|
||||||
state_key: Some("".to_owned()),
|
state_key: Some("".to_owned()),
|
||||||
redacts: None,
|
redacts: None,
|
||||||
|
@ -484,9 +415,7 @@ pub async fn create_room_route(
|
||||||
.build_and_append_pdu(
|
.build_and_append_pdu(
|
||||||
PduBuilder {
|
PduBuilder {
|
||||||
event_type: TimelineEventType::RoomHistoryVisibility,
|
event_type: TimelineEventType::RoomHistoryVisibility,
|
||||||
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
|
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared))
|
||||||
HistoryVisibility::Shared,
|
|
||||||
))
|
|
||||||
.expect("event is valid, we just created it"),
|
.expect("event is valid, we just created it"),
|
||||||
unsigned: None,
|
unsigned: None,
|
||||||
state_key: Some("".to_owned()),
|
state_key: Some("".to_owned()),
|
||||||
|
@ -531,17 +460,11 @@ pub async fn create_room_route(
|
||||||
pdu_builder.state_key.get_or_insert_with(|| "".to_owned());
|
pdu_builder.state_key.get_or_insert_with(|| "".to_owned());
|
||||||
|
|
||||||
// Silently skip encryption events if they are not allowed
|
// Silently skip encryption events if they are not allowed
|
||||||
if pdu_builder.event_type == TimelineEventType::RoomEncryption
|
if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||||
&& !services().globals.allow_encryption()
|
|
||||||
{
|
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await?;
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// 7. Events implied by name and topic
|
// 7. Events implied by name and topic
|
||||||
|
@ -611,26 +534,17 @@ pub async fn create_room_route(
|
||||||
///
|
///
|
||||||
/// Gets a single event.
|
/// Gets a single event.
|
||||||
///
|
///
|
||||||
/// - You have to currently be joined to the room (TODO: Respect history visibility)
|
/// - You have to currently be joined to the room (TODO: Respect history
|
||||||
pub async fn get_room_event_route(
|
/// visibility)
|
||||||
body: Ruma<get_room_event::v3::Request>,
|
pub async fn get_room_event_route(body: Ruma<get_room_event::v3::Request>) -> Result<get_room_event::v3::Response> {
|
||||||
) -> Result<get_room_event::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let event = services()
|
let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(|| {
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.get_pdu(&body.event_id)?
|
|
||||||
.ok_or_else(|| {
|
|
||||||
warn!("Event not found, event ID: {:?}", &body.event_id);
|
warn!("Event not found, event ID: {:?}", &body.event_id);
|
||||||
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
|
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if !services().rooms.state_accessor.user_can_see_event(
|
if !services().rooms.state_accessor.user_can_see_event(sender_user, &event.room_id, &body.event_id)? {
|
||||||
sender_user,
|
|
||||||
&event.room_id,
|
|
||||||
&body.event_id,
|
|
||||||
)? {
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view this event.",
|
"You don't have permission to view this event.",
|
||||||
|
@ -649,17 +563,12 @@ pub async fn get_room_event_route(
|
||||||
///
|
///
|
||||||
/// Lists all aliases of the room.
|
/// Lists all aliases of the room.
|
||||||
///
|
///
|
||||||
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable
|
/// - Only users joined to the room are allowed to call this TODO: Allow any
|
||||||
pub async fn get_room_aliases_route(
|
/// user to call it if history_visibility is world readable
|
||||||
body: Ruma<aliases::v3::Request>,
|
pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> {
|
||||||
) -> Result<aliases::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.is_joined(sender_user, &body.room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view this room.",
|
"You don't have permission to view this room.",
|
||||||
|
@ -686,16 +595,10 @@ pub async fn get_room_aliases_route(
|
||||||
/// - Transfers some state events
|
/// - Transfers some state events
|
||||||
/// - Moves local aliases
|
/// - Moves local aliases
|
||||||
/// - Modifies old room power levels to prevent users from speaking
|
/// - Modifies old room power levels to prevent users from speaking
|
||||||
pub async fn upgrade_room_route(
|
pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> {
|
||||||
body: Ruma<upgrade_room::v3::Request>,
|
|
||||||
) -> Result<upgrade_room::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().globals.supported_room_versions().contains(&body.new_version) {
|
||||||
.globals
|
|
||||||
.supported_room_versions()
|
|
||||||
.contains(&body.new_version)
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::UnsupportedRoomVersion,
|
ErrorKind::UnsupportedRoomVersion,
|
||||||
"This server does not support that room version.",
|
"This server does not support that room version.",
|
||||||
|
@ -704,24 +607,15 @@ pub async fn upgrade_room_route(
|
||||||
|
|
||||||
// Create a replacement room
|
// Create a replacement room
|
||||||
let replacement_room = RoomId::new(services().globals.server_name());
|
let replacement_room = RoomId::new(services().globals.server_name());
|
||||||
services()
|
services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_or_create_shortroomid(&replacement_room)?;
|
|
||||||
|
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(body.room_id.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further
|
// Send a m.room.tombstone event to the old room to indicate that it is not
|
||||||
// Fail if the sender does not have the required permissions
|
// intended to be used any further Fail if the sender does not have the required
|
||||||
|
// permissions
|
||||||
let tombstone_event_id = services()
|
let tombstone_event_id = services()
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
|
@ -745,15 +639,8 @@ pub async fn upgrade_room_route(
|
||||||
|
|
||||||
// Change lock to replacement room
|
// Change lock to replacement room
|
||||||
drop(state_lock);
|
drop(state_lock);
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(replacement_room.clone()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(replacement_room.clone())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
// Get the old room creation event
|
// Get the old room creation event
|
||||||
|
@ -774,7 +661,8 @@ pub async fn upgrade_room_route(
|
||||||
(*tombstone_event_id).to_owned(),
|
(*tombstone_event_id).to_owned(),
|
||||||
));
|
));
|
||||||
|
|
||||||
// Send a m.room.create event containing a predecessor field and the applicable room_version
|
// Send a m.room.create event containing a predecessor field and the applicable
|
||||||
|
// room_version
|
||||||
match body.new_version {
|
match body.new_version {
|
||||||
RoomVersionId::V1
|
RoomVersionId::V1
|
||||||
| RoomVersionId::V2
|
| RoomVersionId::V2
|
||||||
|
@ -793,21 +681,18 @@ pub async fn upgrade_room_route(
|
||||||
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")
|
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")
|
||||||
})?,
|
})?,
|
||||||
);
|
);
|
||||||
}
|
},
|
||||||
RoomVersionId::V11 => {
|
RoomVersionId::V11 => {
|
||||||
// "creator" key no longer exists in V11 rooms
|
// "creator" key no longer exists in V11 rooms
|
||||||
create_event_content.remove("creator");
|
create_event_content.remove("creator");
|
||||||
}
|
},
|
||||||
_ => {
|
_ => {
|
||||||
warn!(
|
warn!("Unexpected or unsupported room version {}", body.new_version);
|
||||||
"Unexpected or unsupported room version {}",
|
|
||||||
body.new_version
|
|
||||||
);
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::BadJson,
|
ErrorKind::BadJson,
|
||||||
"Unexpected or unsupported room version found",
|
"Unexpected or unsupported room version found",
|
||||||
));
|
));
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
create_event_content.insert(
|
create_event_content.insert(
|
||||||
|
@ -825,16 +710,11 @@ pub async fn upgrade_room_route(
|
||||||
|
|
||||||
// Validate creation event content
|
// Validate creation event content
|
||||||
let de_result = serde_json::from_str::<CanonicalJsonObject>(
|
let de_result = serde_json::from_str::<CanonicalJsonObject>(
|
||||||
to_raw_value(&create_event_content)
|
to_raw_value(&create_event_content).expect("Error forming creation event").get(),
|
||||||
.expect("Error forming creation event")
|
|
||||||
.get(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
if de_result.is_err() {
|
if de_result.is_err() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"));
|
||||||
ErrorKind::BadJson,
|
|
||||||
"Error forming creation event",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
services()
|
services()
|
||||||
|
@ -843,8 +723,7 @@ pub async fn upgrade_room_route(
|
||||||
.build_and_append_pdu(
|
.build_and_append_pdu(
|
||||||
PduBuilder {
|
PduBuilder {
|
||||||
event_type: TimelineEventType::RoomCreate,
|
event_type: TimelineEventType::RoomCreate,
|
||||||
content: to_raw_value(&create_event_content)
|
content: to_raw_value(&create_event_content).expect("event is valid, we just created it"),
|
||||||
.expect("event is valid, we just created it"),
|
|
||||||
unsigned: None,
|
unsigned: None,
|
||||||
state_key: Some("".to_owned()),
|
state_key: Some("".to_owned()),
|
||||||
redacts: None,
|
redacts: None,
|
||||||
|
@ -898,12 +777,7 @@ pub async fn upgrade_room_route(
|
||||||
|
|
||||||
// Replicate transferable state events to the new room
|
// Replicate transferable state events to the new room
|
||||||
for event_type in transferable_state_events {
|
for event_type in transferable_state_events {
|
||||||
let event_content =
|
let event_content = match services().rooms.state_accessor.room_state_get(&body.room_id, &event_type, "")? {
|
||||||
match services()
|
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.room_state_get(&body.room_id, &event_type, "")?
|
|
||||||
{
|
|
||||||
Some(v) => v.content.clone(),
|
Some(v) => v.content.clone(),
|
||||||
None => continue, // Skipping missing events.
|
None => continue, // Skipping missing events.
|
||||||
};
|
};
|
||||||
|
@ -927,16 +801,8 @@ pub async fn upgrade_room_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
// Moves any local aliases to the new room
|
// Moves any local aliases to the new room
|
||||||
for alias in services()
|
for alias in services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
services().rooms.alias.set_alias(&alias, &replacement_room)?;
|
||||||
.alias
|
|
||||||
.local_aliases_for_room(&body.room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
services()
|
|
||||||
.rooms
|
|
||||||
.alias
|
|
||||||
.set_alias(&alias, &replacement_room)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get the old room power levels
|
// Get the old room power levels
|
||||||
|
@ -956,15 +822,15 @@ pub async fn upgrade_room_route(
|
||||||
power_levels_event_content.events_default = new_level;
|
power_levels_event_content.events_default = new_level;
|
||||||
power_levels_event_content.invite = new_level;
|
power_levels_event_content.invite = new_level;
|
||||||
|
|
||||||
// Modify the power levels in the old room to prevent sending of events and inviting new users
|
// Modify the power levels in the old room to prevent sending of events and
|
||||||
|
// inviting new users
|
||||||
let _ = services()
|
let _ = services()
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.build_and_append_pdu(
|
.build_and_append_pdu(
|
||||||
PduBuilder {
|
PduBuilder {
|
||||||
event_type: TimelineEventType::RoomPowerLevels,
|
event_type: TimelineEventType::RoomPowerLevels,
|
||||||
content: to_raw_value(&power_levels_event_content)
|
content: to_raw_value(&power_levels_event_content).expect("event is valid, we just created it"),
|
||||||
.expect("event is valid, we just created it"),
|
|
||||||
unsigned: None,
|
unsigned: None,
|
||||||
state_key: Some("".to_owned()),
|
state_key: Some("".to_owned()),
|
||||||
redacts: None,
|
redacts: None,
|
||||||
|
@ -978,5 +844,7 @@ pub async fn upgrade_room_route(
|
||||||
drop(state_lock);
|
drop(state_lock);
|
||||||
|
|
||||||
// Return the replacement room id
|
// Return the replacement room id
|
||||||
Ok(upgrade_room::v3::Response { replacement_room })
|
Ok(upgrade_room::v3::Response {
|
||||||
|
replacement_room,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,4 +1,5 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use ruma::api::client::{
|
use ruma::api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
search::search_events::{
|
search::search_events::{
|
||||||
|
@ -7,28 +8,22 @@ use ruma::api::client::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::collections::BTreeMap;
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/search`
|
/// # `POST /_matrix/client/r0/search`
|
||||||
///
|
///
|
||||||
/// Searches rooms for messages.
|
/// Searches rooms for messages.
|
||||||
///
|
///
|
||||||
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
|
/// - Only works if the user is currently joined to the room (TODO: Respect
|
||||||
pub async fn search_events_route(
|
/// history visibility)
|
||||||
body: Ruma<search_events::v3::Request>,
|
pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> {
|
||||||
) -> Result<search_events::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
|
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
|
||||||
let filter = &search_criteria.filter;
|
let filter = &search_criteria.filter;
|
||||||
|
|
||||||
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
||||||
services()
|
services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok).collect()
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.rooms_joined(sender_user)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
.collect()
|
|
||||||
});
|
});
|
||||||
|
|
||||||
// Use limit or else 10, with maximum 100
|
// Use limit or else 10, with maximum 100
|
||||||
|
@ -37,34 +32,21 @@ pub async fn search_events_route(
|
||||||
let mut searches = Vec::new();
|
let mut searches = Vec::new();
|
||||||
|
|
||||||
for room_id in room_ids {
|
for room_id in room_ids {
|
||||||
if !services()
|
if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.is_joined(sender_user, &room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view this room.",
|
"You don't have permission to view this room.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Some(search) = services()
|
if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? {
|
||||||
.rooms
|
|
||||||
.search
|
|
||||||
.search_pdus(&room_id, &search_criteria.search_term)?
|
|
||||||
{
|
|
||||||
searches.push(search.0.peekable());
|
searches.push(search.0.peekable());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
|
let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
|
||||||
Some(Ok(s)) => s,
|
Some(Ok(s)) => s,
|
||||||
Some(Err(_)) => {
|
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"Invalid next_batch token.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
None => 0, // Default to the start
|
None => 0, // Default to the start
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,3 @@
|
||||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
|
||||||
use crate::{services, utils, Error, Result, Ruma};
|
|
||||||
use argon2::{PasswordHash, PasswordVerifier};
|
use argon2::{PasswordHash, PasswordVerifier};
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
|
@ -22,6 +20,9 @@ use ruma::{
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||||
|
use crate::{services, utils, Error, Result, Ruma};
|
||||||
|
|
||||||
#[derive(Debug, Deserialize)]
|
#[derive(Debug, Deserialize)]
|
||||||
struct Claims {
|
struct Claims {
|
||||||
sub: String,
|
sub: String,
|
||||||
|
@ -30,11 +31,9 @@ struct Claims {
|
||||||
|
|
||||||
/// # `GET /_matrix/client/v3/login`
|
/// # `GET /_matrix/client/v3/login`
|
||||||
///
|
///
|
||||||
/// Get the supported login types of this server. One of these should be used as the `type` field
|
/// Get the supported login types of this server. One of these should be used as
|
||||||
/// when logging in.
|
/// the `type` field when logging in.
|
||||||
pub async fn get_login_types_route(
|
pub async fn get_login_types_route(_body: Ruma<get_login_types::v3::Request>) -> Result<get_login_types::v3::Response> {
|
||||||
_body: Ruma<get_login_types::v3::Request>,
|
|
||||||
) -> Result<get_login_types::v3::Response> {
|
|
||||||
Ok(get_login_types::v3::Response::new(vec to see
|
/// Note: You can use [`GET
|
||||||
|
/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
|
||||||
/// supported login types.
|
/// supported login types.
|
||||||
pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
|
pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
|
||||||
// Validate login method
|
// Validate login method
|
||||||
|
@ -68,16 +70,19 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
debug!("Using username from identifier field");
|
debug!("Using username from identifier field");
|
||||||
user_id.to_lowercase()
|
user_id.to_lowercase()
|
||||||
} else if let Some(user_id) = user {
|
} else if let Some(user_id) = user {
|
||||||
warn!("User \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
warn!(
|
||||||
|
"User \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||||
|
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||||
|
destined to be removed in a future Matrix release.",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
user_id.to_lowercase()
|
user_id.to_lowercase()
|
||||||
} else {
|
} else {
|
||||||
warn!("Bad login type: {:?}", &body.login_info);
|
warn!("Bad login type: {:?}", &body.login_info);
|
||||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||||
};
|
};
|
||||||
|
|
||||||
let user_id =
|
let user_id = UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||||
UserId::parse_with_server_name(username, services().globals.server_name())
|
|
||||||
.map_err(|e| {
|
|
||||||
warn!("Failed to parse username from user logging in: {}", e);
|
warn!("Failed to parse username from user logging in: {}", e);
|
||||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||||
})?;
|
})?;
|
||||||
|
@ -85,16 +90,10 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
let hash = services()
|
let hash = services()
|
||||||
.users
|
.users
|
||||||
.password_hash(&user_id)?
|
.password_hash(&user_id)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."))?;
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Wrong username or password.",
|
|
||||||
))?;
|
|
||||||
|
|
||||||
if hash.is_empty() {
|
if hash.is_empty() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated"));
|
||||||
ErrorKind::UserDeactivated,
|
|
||||||
"The user has been deactivated",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let Ok(parsed_hash) = PasswordHash::new(&hash) else {
|
let Ok(parsed_hash) = PasswordHash::new(&hash) else {
|
||||||
|
@ -102,29 +101,21 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
return Err(Error::BadServerResponse("could not hash"));
|
return Err(Error::BadServerResponse("could not hash"));
|
||||||
};
|
};
|
||||||
|
|
||||||
let hash_matches = services()
|
let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok();
|
||||||
.globals
|
|
||||||
.argon
|
|
||||||
.verify_password(password.as_bytes(), &parsed_hash)
|
|
||||||
.is_ok();
|
|
||||||
|
|
||||||
if !hash_matches {
|
if !hash_matches {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Wrong username or password.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
user_id
|
user_id
|
||||||
}
|
},
|
||||||
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
login::v3::LoginInfo::Token(login::v3::Token {
|
||||||
|
token,
|
||||||
|
}) => {
|
||||||
debug!("Got token login type");
|
debug!("Got token login type");
|
||||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||||
let token = jsonwebtoken::decode::<Claims>(
|
let token =
|
||||||
token,
|
jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default())
|
||||||
jwt_decoding_key,
|
|
||||||
&jsonwebtoken::Validation::default(),
|
|
||||||
)
|
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!("Failed to parse JWT token from user logging in: {}", e);
|
warn!("Failed to parse JWT token from user logging in: {}", e);
|
||||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||||
|
@ -132,19 +123,17 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
|
|
||||||
let username = token.claims.sub.to_lowercase();
|
let username = token.claims.sub.to_lowercase();
|
||||||
|
|
||||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||||
|e| {
|
|
||||||
warn!("Failed to parse username from user logging in: {}", e);
|
warn!("Failed to parse username from user logging in: {}", e);
|
||||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||||
},
|
})?
|
||||||
)?
|
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Unknown,
|
ErrorKind::Unknown,
|
||||||
"Token login is not supported (server has no jwt decoding key).",
|
"Token login is not supported (server has no jwt decoding key).",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
||||||
identifier,
|
identifier,
|
||||||
|
@ -152,79 +141,65 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
}) => {
|
}) => {
|
||||||
debug!("Got appservice login type");
|
debug!("Got appservice login type");
|
||||||
if !body.from_appservice {
|
if !body.from_appservice {
|
||||||
info!("User tried logging in as an appservice, but request body is not from a known/registered appservice");
|
info!(
|
||||||
return Err(Error::BadRequest(
|
"User tried logging in as an appservice, but request body is not from a known/registered \
|
||||||
ErrorKind::Forbidden,
|
appservice"
|
||||||
"Forbidden login type.",
|
);
|
||||||
));
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden login type."));
|
||||||
};
|
};
|
||||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||||
user_id.to_lowercase()
|
user_id.to_lowercase()
|
||||||
} else if let Some(user_id) = user {
|
} else if let Some(user_id) = user {
|
||||||
warn!("Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
warn!(
|
||||||
|
"Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||||
|
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||||
|
destined to be removed in a future Matrix release.",
|
||||||
|
user_id
|
||||||
|
);
|
||||||
user_id.to_lowercase()
|
user_id.to_lowercase()
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||||
};
|
};
|
||||||
|
|
||||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||||
|e| {
|
|
||||||
warn!("Failed to parse username from appservice logging in: {}", e);
|
warn!("Failed to parse username from appservice logging in: {}", e);
|
||||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||||
|
})?
|
||||||
},
|
},
|
||||||
)?
|
|
||||||
}
|
|
||||||
_ => {
|
_ => {
|
||||||
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
|
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
|
||||||
debug!("JSON body: {:?}", &body.json_body);
|
debug!("JSON body: {:?}", &body.json_body);
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported or unknown login type."));
|
||||||
ErrorKind::Unknown,
|
},
|
||||||
"Unsupported or unknown login type.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
// Generate new device id if the user didn't specify one
|
// Generate new device id if the user didn't specify one
|
||||||
let device_id = body
|
let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||||
.device_id
|
|
||||||
.clone()
|
|
||||||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
|
||||||
|
|
||||||
// Generate a new token for the device
|
// Generate a new token for the device
|
||||||
let token = utils::random_string(TOKEN_LENGTH);
|
let token = utils::random_string(TOKEN_LENGTH);
|
||||||
|
|
||||||
// Determine if device_id was provided and exists in the db for this user
|
// Determine if device_id was provided and exists in the db for this user
|
||||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||||
services()
|
services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||||
.users
|
|
||||||
.all_device_ids(&user_id)
|
|
||||||
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
|
||||||
});
|
});
|
||||||
|
|
||||||
if device_exists {
|
if device_exists {
|
||||||
services().users.set_token(&user_id, &device_id, &token)?;
|
services().users.set_token(&user_id, &device_id, &token)?;
|
||||||
} else {
|
} else {
|
||||||
services().users.create_device(
|
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||||
&user_id,
|
|
||||||
&device_id,
|
|
||||||
&token,
|
|
||||||
body.initial_device_display_name.clone(),
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// send client well-known if specified so the client knows to reconfigure itself
|
// send client well-known if specified so the client knows to reconfigure itself
|
||||||
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
|
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
|
||||||
services()
|
services().globals.well_known_client().to_owned().unwrap_or_else(|| "".to_owned()),
|
||||||
.globals
|
|
||||||
.well_known_client()
|
|
||||||
.to_owned()
|
|
||||||
.unwrap_or_else(|| "".to_owned()),
|
|
||||||
));
|
));
|
||||||
|
|
||||||
info!("{} logged in", user_id);
|
info!("{} logged in", user_id);
|
||||||
|
|
||||||
// home_server is deprecated but apparently must still be sent despite it being deprecated over 6 years ago.
|
// home_server is deprecated but apparently must still be sent despite it being
|
||||||
// initially i thought this macro was unnecessary, but ruma uses this same macro for the same reason so...
|
// deprecated over 6 years ago. initially i thought this macro was unnecessary,
|
||||||
|
// but ruma uses this same macro for the same reason so...
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
Ok(login::v3::Response {
|
Ok(login::v3::Response {
|
||||||
user_id,
|
user_id,
|
||||||
|
@ -248,7 +223,8 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
||||||
/// Log out the current device.
|
/// Log out the current device.
|
||||||
///
|
///
|
||||||
/// - Invalidates access token
|
/// - Invalidates access token
|
||||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets to-device events
|
/// - Forgets to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
|
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
|
||||||
|
@ -268,15 +244,15 @@ pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3:
|
||||||
/// Log out all devices of this user.
|
/// Log out all devices of this user.
|
||||||
///
|
///
|
||||||
/// - Invalidates all access tokens
|
/// - Invalidates all access tokens
|
||||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||||
|
/// last seen ts)
|
||||||
/// - Forgets all to-device events
|
/// - Forgets all to-device events
|
||||||
/// - Triggers device list updates
|
/// - Triggers device list updates
|
||||||
///
|
///
|
||||||
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html)
|
/// Note: This is equivalent to calling [`GET
|
||||||
/// from each device of this user.
|
/// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this
|
||||||
pub async fn logout_all_route(
|
/// user.
|
||||||
body: Ruma<logout_all::v3::Request>,
|
pub async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> {
|
||||||
) -> Result<logout_all::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
||||||
|
|
|
@ -1,34 +1,19 @@
|
||||||
use crate::{services, Result, Ruma};
|
|
||||||
use ruma::api::client::space::get_hierarchy;
|
use ruma::api::client::space::get_hierarchy;
|
||||||
|
|
||||||
|
use crate::{services, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy``
|
/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy``
|
||||||
///
|
///
|
||||||
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space.
|
/// Paginates over the space tree in a depth-first manner to locate child rooms
|
||||||
pub async fn get_hierarchy_route(
|
/// of a given space.
|
||||||
body: Ruma<get_hierarchy::v1::Request>,
|
pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
|
||||||
) -> Result<get_hierarchy::v1::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let skip = body
|
let skip = body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
|
||||||
.from
|
|
||||||
.as_ref()
|
|
||||||
.and_then(|s| s.parse::<usize>().ok())
|
|
||||||
.unwrap_or(0);
|
|
||||||
|
|
||||||
let limit = body.limit.map_or(10, u64::from).min(100) as usize;
|
let limit = body.limit.map_or(10, u64::from).min(100) as usize;
|
||||||
|
|
||||||
let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself
|
let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself
|
||||||
|
|
||||||
services()
|
services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await
|
||||||
.rooms
|
|
||||||
.spaces
|
|
||||||
.get_hierarchy(
|
|
||||||
sender_user,
|
|
||||||
&body.room_id,
|
|
||||||
limit,
|
|
||||||
skip,
|
|
||||||
max_depth,
|
|
||||||
body.suggested_only,
|
|
||||||
)
|
|
||||||
.await
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,25 +1,25 @@
|
||||||
use std::sync::Arc;
|
use std::sync::Arc;
|
||||||
|
|
||||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::{
|
api::client::{
|
||||||
error::ErrorKind,
|
error::ErrorKind,
|
||||||
state::{get_state_events, get_state_events_for_key, send_state_event},
|
state::{get_state_events, get_state_events_for_key, send_state_event},
|
||||||
},
|
},
|
||||||
events::{
|
events::{room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType},
|
||||||
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType,
|
|
||||||
},
|
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
EventId, RoomId, UserId,
|
EventId, RoomId, UserId,
|
||||||
};
|
};
|
||||||
use tracing::{error, log::warn};
|
use tracing::{error, log::warn};
|
||||||
|
|
||||||
|
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
|
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
|
||||||
///
|
///
|
||||||
/// Sends a state event into the room.
|
/// Sends a state event into the room.
|
||||||
///
|
///
|
||||||
/// - The only requirement for the content is that it has to be valid json
|
/// - The only requirement for the content is that it has to be valid json
|
||||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||||
|
/// allowed
|
||||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||||
pub async fn send_state_event_for_key_route(
|
pub async fn send_state_event_for_key_route(
|
||||||
body: Ruma<send_state_event::v3::Request>,
|
body: Ruma<send_state_event::v3::Request>,
|
||||||
|
@ -36,7 +36,9 @@ pub async fn send_state_event_for_key_route(
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let event_id = (*event_id).to_owned();
|
let event_id = (*event_id).to_owned();
|
||||||
Ok(send_state_event::v3::Response { event_id })
|
Ok(send_state_event::v3::Response {
|
||||||
|
event_id,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
|
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
|
||||||
|
@ -44,7 +46,8 @@ pub async fn send_state_event_for_key_route(
|
||||||
/// Sends a state event into the room.
|
/// Sends a state event into the room.
|
||||||
///
|
///
|
||||||
/// - The only requirement for the content is that it has to be valid json
|
/// - The only requirement for the content is that it has to be valid json
|
||||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||||
|
/// allowed
|
||||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||||
pub async fn send_state_event_for_empty_key_route(
|
pub async fn send_state_event_for_empty_key_route(
|
||||||
body: Ruma<send_state_event::v3::Request>,
|
body: Ruma<send_state_event::v3::Request>,
|
||||||
|
@ -53,10 +56,7 @@ pub async fn send_state_event_for_empty_key_route(
|
||||||
|
|
||||||
// Forbid m.room.encryption if encryption is disabled
|
// Forbid m.room.encryption if encryption is disabled
|
||||||
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Encryption has been disabled",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let event_id = send_state_event_for_key_helper(
|
let event_id = send_state_event_for_key_helper(
|
||||||
|
@ -69,24 +69,24 @@ pub async fn send_state_event_for_empty_key_route(
|
||||||
.await?;
|
.await?;
|
||||||
|
|
||||||
let event_id = (*event_id).to_owned();
|
let event_id = (*event_id).to_owned();
|
||||||
Ok(send_state_event::v3::Response { event_id }.into())
|
Ok(send_state_event::v3::Response {
|
||||||
|
event_id,
|
||||||
|
}
|
||||||
|
.into())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/rooms/{roomid}/state`
|
/// # `GET /_matrix/client/r0/rooms/{roomid}/state`
|
||||||
///
|
///
|
||||||
/// Get all state events for a room.
|
/// Get all state events for a room.
|
||||||
///
|
///
|
||||||
/// - If not joined: Only works if current room history visibility is world readable
|
/// - If not joined: Only works if current room history visibility is world
|
||||||
|
/// readable
|
||||||
pub async fn get_state_events_route(
|
pub async fn get_state_events_route(
|
||||||
body: Ruma<get_state_events::v3::Request>,
|
body: Ruma<get_state_events::v3::Request>,
|
||||||
) -> Result<get_state_events::v3::Response> {
|
) -> Result<get_state_events::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view the room state.",
|
"You don't have permission to view the room state.",
|
||||||
|
@ -108,42 +108,31 @@ pub async fn get_state_events_route(
|
||||||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}`
|
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}`
|
||||||
///
|
///
|
||||||
/// Get single state event of a room with the specified state key.
|
/// Get single state event of a room with the specified state key.
|
||||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
/// The optional query parameter `?format=event|content` allows returning the
|
||||||
/// or just the state event's content (default behaviour)
|
/// full room state event or just the state event's content (default behaviour)
|
||||||
///
|
///
|
||||||
/// - If not joined: Only works if current room history visibility is world readable
|
/// - If not joined: Only works if current room history visibility is world
|
||||||
|
/// readable
|
||||||
pub async fn get_state_events_for_key_route(
|
pub async fn get_state_events_for_key_route(
|
||||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||||
) -> Result<get_state_events_for_key::v3::Response> {
|
) -> Result<get_state_events_for_key::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view the room state.",
|
"You don't have permission to view the room state.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let event = services()
|
let event =
|
||||||
.rooms
|
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else(
|
||||||
.state_accessor
|
|| {
|
||||||
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
|
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||||
.ok_or_else(|| {
|
|
||||||
warn!(
|
|
||||||
"State event {:?} not found in room {:?}",
|
|
||||||
&body.event_type, &body.room_id
|
|
||||||
);
|
|
||||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||||
})?;
|
},
|
||||||
if body
|
)?;
|
||||||
.format
|
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
|
||||||
{
|
|
||||||
Ok(get_state_events_for_key::v3::Response {
|
Ok(get_state_events_for_key::v3::Response {
|
||||||
content: None,
|
content: None,
|
||||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||||
|
@ -165,43 +154,30 @@ pub async fn get_state_events_for_key_route(
|
||||||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}`
|
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}`
|
||||||
///
|
///
|
||||||
/// Get single state event of a room.
|
/// Get single state event of a room.
|
||||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
/// The optional query parameter `?format=event|content` allows returning the
|
||||||
/// or just the state event's content (default behaviour)
|
/// full room state event or just the state event's content (default behaviour)
|
||||||
///
|
///
|
||||||
/// - If not joined: Only works if current room history visibility is world readable
|
/// - If not joined: Only works if current room history visibility is world
|
||||||
|
/// readable
|
||||||
pub async fn get_state_events_for_empty_key_route(
|
pub async fn get_state_events_for_empty_key_route(
|
||||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||||
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
|
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You don't have permission to view the room state.",
|
"You don't have permission to view the room state.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
let event = services()
|
let event =
|
||||||
.rooms
|
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| {
|
||||||
.state_accessor
|
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||||
.room_state_get(&body.room_id, &body.event_type, "")?
|
|
||||||
.ok_or_else(|| {
|
|
||||||
warn!(
|
|
||||||
"State event {:?} not found in room {:?}",
|
|
||||||
&body.event_type, &body.room_id
|
|
||||||
);
|
|
||||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
if body
|
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||||
.format
|
|
||||||
.as_ref()
|
|
||||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
|
||||||
{
|
|
||||||
Ok(get_state_events_for_key::v3::Response {
|
Ok(get_state_events_for_key::v3::Response {
|
||||||
content: None,
|
content: None,
|
||||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||||
|
@ -223,19 +199,13 @@ pub async fn get_state_events_for_empty_key_route(
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn send_state_event_for_key_helper(
|
async fn send_state_event_for_key_helper(
|
||||||
sender: &UserId,
|
sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String,
|
||||||
room_id: &RoomId,
|
|
||||||
event_type: &StateEventType,
|
|
||||||
json: &Raw<AnyStateEventContent>,
|
|
||||||
state_key: String,
|
|
||||||
) -> Result<Arc<EventId>> {
|
) -> Result<Arc<EventId>> {
|
||||||
let sender_user = sender;
|
let sender_user = sender;
|
||||||
|
|
||||||
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it
|
// TODO: Review this check, error if event is unparsable, use event type, allow
|
||||||
// previously existed
|
// alias if it previously existed
|
||||||
if let Ok(canonical_alias) =
|
if let Ok(canonical_alias) = serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) {
|
||||||
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get())
|
|
||||||
{
|
|
||||||
let mut aliases = canonical_alias.alt_aliases.clone();
|
let mut aliases = canonical_alias.alt_aliases.clone();
|
||||||
|
|
||||||
if let Some(alias) = canonical_alias.alias {
|
if let Some(alias) = canonical_alias.alias {
|
||||||
|
@ -253,22 +223,14 @@ async fn send_state_event_for_key_helper(
|
||||||
{
|
{
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"You are only allowed to send canonical_alias \
|
"You are only allowed to send canonical_alias events when it's aliases already exists",
|
||||||
events when it's aliases already exists",
|
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mutex_state = Arc::clone(
|
let mutex_state =
|
||||||
services()
|
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
|
||||||
.globals
|
|
||||||
.roomid_mutex_state
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.entry(room_id.to_owned())
|
|
||||||
.or_default(),
|
|
||||||
);
|
|
||||||
let state_lock = mutex_state.lock().await;
|
let state_lock = mutex_state.lock().await;
|
||||||
|
|
||||||
let event_id = services()
|
let event_id = services()
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -1,4 +1,5 @@
|
||||||
use crate::{services, Error, Result, Ruma};
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::tag::{create_tag, delete_tag, get_tags},
|
api::client::tag::{create_tag, delete_tag, get_tags},
|
||||||
events::{
|
events::{
|
||||||
|
@ -6,29 +7,21 @@ use ruma::{
|
||||||
RoomAccountDataEventType,
|
RoomAccountDataEventType,
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
use std::collections::BTreeMap;
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
|
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
|
||||||
///
|
///
|
||||||
/// Adds a tag to the room.
|
/// Adds a tag to the room.
|
||||||
///
|
///
|
||||||
/// - Inserts the tag into the tag event of the room account data.
|
/// - Inserts the tag into the tag event of the room account data.
|
||||||
pub async fn update_tag_route(
|
pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> {
|
||||||
body: Ruma<create_tag::v3::Request>,
|
|
||||||
) -> Result<create_tag::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let event = services().account_data.get(
|
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||||
Some(&body.room_id),
|
|
||||||
sender_user,
|
|
||||||
RoomAccountDataEventType::Tag,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mut tags_event = event
|
let mut tags_event = event
|
||||||
.map(|e| {
|
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||||
serde_json::from_str(e.get())
|
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
Ok(TagEvent {
|
Ok(TagEvent {
|
||||||
content: TagEventContent {
|
content: TagEventContent {
|
||||||
|
@ -37,10 +30,7 @@ pub async fn update_tag_route(
|
||||||
})
|
})
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
tags_event
|
tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||||
.content
|
|
||||||
.tags
|
|
||||||
.insert(body.tag.clone().into(), body.tag_info.clone());
|
|
||||||
|
|
||||||
services().account_data.update(
|
services().account_data.update(
|
||||||
Some(&body.room_id),
|
Some(&body.room_id),
|
||||||
|
@ -57,22 +47,13 @@ pub async fn update_tag_route(
|
||||||
/// Deletes a tag from the room.
|
/// Deletes a tag from the room.
|
||||||
///
|
///
|
||||||
/// - Removes the tag from the tag event of the room account data.
|
/// - Removes the tag from the tag event of the room account data.
|
||||||
pub async fn delete_tag_route(
|
pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> {
|
||||||
body: Ruma<delete_tag::v3::Request>,
|
|
||||||
) -> Result<delete_tag::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let event = services().account_data.get(
|
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||||
Some(&body.room_id),
|
|
||||||
sender_user,
|
|
||||||
RoomAccountDataEventType::Tag,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let mut tags_event = event
|
let mut tags_event = event
|
||||||
.map(|e| {
|
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||||
serde_json::from_str(e.get())
|
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
Ok(TagEvent {
|
Ok(TagEvent {
|
||||||
content: TagEventContent {
|
content: TagEventContent {
|
||||||
|
@ -101,17 +82,10 @@ pub async fn delete_tag_route(
|
||||||
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
|
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
let event = services().account_data.get(
|
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||||
Some(&body.room_id),
|
|
||||||
sender_user,
|
|
||||||
RoomAccountDataEventType::Tag,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let tags_event = event
|
let tags_event = event
|
||||||
.map(|e| {
|
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||||
serde_json::from_str(e.get())
|
|
||||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
|
||||||
})
|
|
||||||
.unwrap_or_else(|| {
|
.unwrap_or_else(|| {
|
||||||
Ok(TagEvent {
|
Ok(TagEvent {
|
||||||
content: TagEventContent {
|
content: TagEventContent {
|
||||||
|
|
|
@ -1,14 +1,13 @@
|
||||||
use crate::{Result, Ruma};
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use ruma::api::client::thirdparty::get_protocols;
|
use ruma::api::client::thirdparty::get_protocols;
|
||||||
|
|
||||||
use std::collections::BTreeMap;
|
use crate::{Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/thirdparty/protocols`
|
/// # `GET /_matrix/client/r0/thirdparty/protocols`
|
||||||
///
|
///
|
||||||
/// TODO: Fetches all metadata about protocols supported by the homeserver.
|
/// TODO: Fetches all metadata about protocols supported by the homeserver.
|
||||||
pub async fn get_protocols_route(
|
pub async fn get_protocols_route(_body: Ruma<get_protocols::v3::Request>) -> Result<get_protocols::v3::Response> {
|
||||||
_body: Ruma<get_protocols::v3::Request>,
|
|
||||||
) -> Result<get_protocols::v3::Response> {
|
|
||||||
// TODO
|
// TODO
|
||||||
Ok(get_protocols::v3::Response {
|
Ok(get_protocols::v3::Response {
|
||||||
protocols: BTreeMap::new(),
|
protocols: BTreeMap::new(),
|
||||||
|
|
|
@ -3,21 +3,14 @@ use ruma::api::client::{error::ErrorKind, threads::get_threads};
|
||||||
use crate::{services, Error, Result, Ruma};
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
|
/// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
|
||||||
pub async fn get_threads_route(
|
pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> {
|
||||||
body: Ruma<get_threads::v1::Request>,
|
|
||||||
) -> Result<get_threads::v1::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
// Use limit or else 10, with maximum 100
|
// Use limit or else 10, with maximum 100
|
||||||
let limit = body
|
let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);
|
||||||
.limit
|
|
||||||
.and_then(|l| l.try_into().ok())
|
|
||||||
.unwrap_or(10)
|
|
||||||
.min(100);
|
|
||||||
|
|
||||||
let from = if let Some(from) = &body.from {
|
let from = if let Some(from) = &body.from {
|
||||||
from.parse()
|
from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
|
|
||||||
} else {
|
} else {
|
||||||
u64::MAX
|
u64::MAX
|
||||||
};
|
};
|
||||||
|
@ -40,10 +33,7 @@ pub async fn get_threads_route(
|
||||||
let next_batch = threads.last().map(|(count, _)| count.to_string());
|
let next_batch = threads.last().map(|(count, _)| count.to_string());
|
||||||
|
|
||||||
Ok(get_threads::v1::Response {
|
Ok(get_threads::v1::Response {
|
||||||
chunk: threads
|
chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(),
|
||||||
.into_iter()
|
|
||||||
.map(|(_, pdu)| pdu.to_room_event())
|
|
||||||
.collect(),
|
|
||||||
next_batch,
|
next_batch,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use crate::{services, Error, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::{error::ErrorKind, to_device::send_event_to_device},
|
client::{error::ErrorKind, to_device::send_event_to_device},
|
||||||
|
@ -9,6 +8,8 @@ use ruma::{
|
||||||
to_device::DeviceIdOrAllDevices,
|
to_device::DeviceIdOrAllDevices,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
|
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
|
||||||
///
|
///
|
||||||
/// Send a to-device event to a set of client devices.
|
/// Send a to-device event to a set of client devices.
|
||||||
|
@ -19,11 +20,7 @@ pub async fn send_event_to_device_route(
|
||||||
let sender_device = body.sender_device.as_deref();
|
let sender_device = body.sender_device.as_deref();
|
||||||
|
|
||||||
// Check if this is a new transaction id
|
// Check if this is a new transaction id
|
||||||
if services()
|
if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() {
|
||||||
.transaction_ids
|
|
||||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
|
||||||
.is_some()
|
|
||||||
{
|
|
||||||
return Ok(send_event_to_device::v3::Response {});
|
return Ok(send_event_to_device::v3::Response {});
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -38,14 +35,12 @@ pub async fn send_event_to_device_route(
|
||||||
|
|
||||||
services().sending.send_reliable_edu(
|
services().sending.send_reliable_edu(
|
||||||
target_user_id.server_name(),
|
target_user_id.server_name(),
|
||||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
|
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
|
||||||
DirectDeviceContent {
|
|
||||||
sender: sender_user.clone(),
|
sender: sender_user.clone(),
|
||||||
ev_type: body.event_type.clone(),
|
ev_type: body.event_type.clone(),
|
||||||
message_id: count.to_string().into(),
|
message_id: count.to_string().into(),
|
||||||
messages,
|
messages,
|
||||||
},
|
}))
|
||||||
))
|
|
||||||
.expect("DirectToDevice EDU can be serialized"),
|
.expect("DirectToDevice EDU can be serialized"),
|
||||||
count,
|
count,
|
||||||
)?;
|
)?;
|
||||||
|
@ -60,11 +55,11 @@ pub async fn send_event_to_device_route(
|
||||||
target_user_id,
|
target_user_id,
|
||||||
target_device_id,
|
target_device_id,
|
||||||
&body.event_type.to_string(),
|
&body.event_type.to_string(),
|
||||||
event.deserialize_as().map_err(|_| {
|
event
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
.deserialize_as()
|
||||||
})?,
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||||
)?;
|
)?;
|
||||||
}
|
},
|
||||||
|
|
||||||
DeviceIdOrAllDevices::AllDevices => {
|
DeviceIdOrAllDevices::AllDevices => {
|
||||||
for target_device_id in services().users.all_device_ids(target_user_id) {
|
for target_device_id in services().users.all_device_ids(target_user_id) {
|
||||||
|
@ -73,20 +68,18 @@ pub async fn send_event_to_device_route(
|
||||||
target_user_id,
|
target_user_id,
|
||||||
&target_device_id?,
|
&target_device_id?,
|
||||||
&body.event_type.to_string(),
|
&body.event_type.to_string(),
|
||||||
event.deserialize_as().map_err(|_| {
|
event
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
.deserialize_as()
|
||||||
})?,
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Save transaction id with empty data
|
// Save transaction id with empty data
|
||||||
services()
|
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||||
.transaction_ids
|
|
||||||
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
|
||||||
|
|
||||||
Ok(send_event_to_device::v3::Response {})
|
Ok(send_event_to_device::v3::Response {})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::{services, utils, Error, Result, Ruma};
|
|
||||||
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
|
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};
|
||||||
|
|
||||||
|
use crate::{services, utils, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
|
/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
|
||||||
///
|
///
|
||||||
/// Sets the typing state of the sender user.
|
/// Sets the typing state of the sender user.
|
||||||
|
@ -11,15 +12,8 @@ pub async fn create_typing_event_route(
|
||||||
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
|
|
||||||
if !services()
|
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
|
||||||
.rooms
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room."));
|
||||||
.state_cache
|
|
||||||
.is_joined(sender_user, &body.room_id)?
|
|
||||||
{
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"You are not in this room.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if let Typing::Yes(duration) = body.state {
|
if let Typing::Yes(duration) = body.state {
|
||||||
|
@ -29,11 +23,7 @@ pub async fn create_typing_event_route(
|
||||||
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
|
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
|
||||||
)?;
|
)?;
|
||||||
} else {
|
} else {
|
||||||
services()
|
services().rooms.edus.typing.typing_remove(sender_user, &body.room_id)?;
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.typing
|
|
||||||
.typing_remove(sender_user, &body.room_id)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(create_typing_event::v3::Response {})
|
Ok(create_typing_event::v3::Response {})
|
||||||
|
|
|
@ -7,14 +7,16 @@ use crate::{services, Error, Result, Ruma};
|
||||||
|
|
||||||
/// # `GET /_matrix/client/versions`
|
/// # `GET /_matrix/client/versions`
|
||||||
///
|
///
|
||||||
/// Get the versions of the specification and unstable features supported by this server.
|
/// Get the versions of the specification and unstable features supported by
|
||||||
|
/// this server.
|
||||||
///
|
///
|
||||||
/// - Versions take the form MAJOR.MINOR.PATCH
|
/// - Versions take the form MAJOR.MINOR.PATCH
|
||||||
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
|
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
|
||||||
/// - Unstable features are namespaced and may include version information in their name
|
/// - Unstable features are namespaced and may include version information in
|
||||||
|
/// their name
|
||||||
///
|
///
|
||||||
/// Note: Unstable features are used while developing new features. Clients should avoid using
|
/// Note: Unstable features are used while developing new features. Clients
|
||||||
/// unstable features in their stable releases
|
/// should avoid using unstable features in their stable releases
|
||||||
pub async fn get_supported_versions_route(
|
pub async fn get_supported_versions_route(
|
||||||
_body: Ruma<get_supported_versions::Request>,
|
_body: Ruma<get_supported_versions::Request>,
|
||||||
) -> Result<get_supported_versions::Response> {
|
) -> Result<get_supported_versions::Response> {
|
||||||
|
@ -60,8 +62,8 @@ pub async fn well_known_client_route() -> Result<impl IntoResponse> {
|
||||||
|
|
||||||
/// # `GET /client/server.json`
|
/// # `GET /client/server.json`
|
||||||
///
|
///
|
||||||
/// Endpoint provided by sliding sync proxy used by some clients such as Element Web
|
/// Endpoint provided by sliding sync proxy used by some clients such as Element
|
||||||
/// as a non-standard health check.
|
/// Web as a non-standard health check.
|
||||||
pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> {
|
pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> {
|
||||||
let server_url = match services().globals.well_known_client() {
|
let server_url = match services().globals.well_known_client() {
|
||||||
Some(url) => url.clone(),
|
Some(url) => url.clone(),
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
use crate::{services, Result, Ruma};
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::user_directory::search_users,
|
api::client::user_directory::search_users,
|
||||||
events::{
|
events::{
|
||||||
|
@ -7,15 +6,16 @@ use ruma::{
|
||||||
},
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::{services, Result, Ruma};
|
||||||
|
|
||||||
/// # `POST /_matrix/client/r0/user_directory/search`
|
/// # `POST /_matrix/client/r0/user_directory/search`
|
||||||
///
|
///
|
||||||
/// Searches all known users for a match.
|
/// Searches all known users for a match.
|
||||||
///
|
///
|
||||||
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
|
/// - Hides any local users that aren't in any public rooms (i.e. those that
|
||||||
|
/// have the join rule set to public)
|
||||||
/// and don't share a room with the sender
|
/// and don't share a room with the sender
|
||||||
pub async fn search_users_route(
|
pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> {
|
||||||
body: Ruma<search_users::v3::Request>,
|
|
||||||
) -> Result<search_users::v3::Response> {
|
|
||||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||||
let limit = u64::from(body.limit) as usize;
|
let limit = u64::from(body.limit) as usize;
|
||||||
|
|
||||||
|
@ -29,56 +29,37 @@ pub async fn search_users_route(
|
||||||
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
||||||
};
|
};
|
||||||
|
|
||||||
let user_id_matches = user
|
let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase());
|
||||||
.user_id
|
|
||||||
.to_string()
|
|
||||||
.to_lowercase()
|
|
||||||
.contains(&body.search_term.to_lowercase());
|
|
||||||
|
|
||||||
let user_displayname_matches = user
|
let user_displayname_matches = user
|
||||||
.display_name
|
.display_name
|
||||||
.as_ref()
|
.as_ref()
|
||||||
.filter(|name| {
|
.filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase()))
|
||||||
name.to_lowercase()
|
|
||||||
.contains(&body.search_term.to_lowercase())
|
|
||||||
})
|
|
||||||
.is_some();
|
.is_some();
|
||||||
|
|
||||||
if !user_id_matches && !user_displayname_matches {
|
if !user_id_matches && !user_displayname_matches {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_is_in_public_rooms = services()
|
let user_is_in_public_rooms =
|
||||||
.rooms
|
services().rooms.state_cache.rooms_joined(&user_id).filter_map(std::result::Result::ok).any(|room| {
|
||||||
.state_cache
|
services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or(
|
||||||
.rooms_joined(&user_id)
|
false,
|
||||||
.filter_map(std::result::Result::ok)
|
|event| {
|
||||||
.any(|room| {
|
|
||||||
services()
|
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
|
|
||||||
.map_or(false, |event| {
|
|
||||||
event.map_or(false, |event| {
|
event.map_or(false, |event| {
|
||||||
serde_json::from_str(event.content.get())
|
serde_json::from_str(event.content.get())
|
||||||
.map_or(false, |r: RoomJoinRulesEventContent| {
|
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
|
||||||
r.join_rule == JoinRule::Public
|
|
||||||
})
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
|
},
|
||||||
|
)
|
||||||
});
|
});
|
||||||
|
|
||||||
if user_is_in_public_rooms {
|
if user_is_in_public_rooms {
|
||||||
return Some(user);
|
return Some(user);
|
||||||
}
|
}
|
||||||
|
|
||||||
let user_is_in_shared_rooms = services()
|
let user_is_in_shared_rooms =
|
||||||
.rooms
|
services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some();
|
||||||
.user
|
|
||||||
.get_shared_rooms(vec![sender_user.clone(), user_id])
|
|
||||||
.ok()?
|
|
||||||
.next()
|
|
||||||
.is_some();
|
|
||||||
|
|
||||||
if user_is_in_shared_rooms {
|
if user_is_in_shared_rooms {
|
||||||
return Some(user);
|
return Some(user);
|
||||||
|
@ -90,5 +71,8 @@ pub async fn search_users_route(
|
||||||
let results = users.by_ref().take(limit).collect();
|
let results = users.by_ref().take(limit).collect();
|
||||||
let limited = users.next().is_some();
|
let limited = users.next().is_some();
|
||||||
|
|
||||||
Ok(search_users::v3::Response { results, limited })
|
Ok(search_users::v3::Response {
|
||||||
|
results,
|
||||||
|
limited,
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,11 @@
|
||||||
use crate::{services, Result, Ruma};
|
use std::time::{Duration, SystemTime};
|
||||||
|
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
use hmac::{Hmac, Mac};
|
use hmac::{Hmac, Mac};
|
||||||
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
|
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
|
||||||
use sha1::Sha1;
|
use sha1::Sha1;
|
||||||
use std::time::{Duration, SystemTime};
|
|
||||||
|
use crate::{services, Result, Ruma};
|
||||||
|
|
||||||
type HmacSha1 = Hmac<Sha1>;
|
type HmacSha1 = Hmac<Sha1>;
|
||||||
|
|
||||||
|
@ -25,8 +27,7 @@ pub async fn turn_server_route(
|
||||||
|
|
||||||
let username: String = format!("{}:{}", expiry.get(), sender_user);
|
let username: String = format!("{}:{}", expiry.get(), sender_user);
|
||||||
|
|
||||||
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes())
|
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
|
||||||
.expect("HMAC can take key of any size");
|
|
||||||
mac.update(username.as_bytes());
|
mac.update(username.as_bytes());
|
||||||
|
|
||||||
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
|
let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
|
||||||
|
|
|
@ -43,18 +43,16 @@ where
|
||||||
let (mut parts, mut body) = match req.with_limited_body() {
|
let (mut parts, mut body) = match req.with_limited_body() {
|
||||||
Ok(limited_req) => {
|
Ok(limited_req) => {
|
||||||
let (parts, body) = limited_req.into_parts();
|
let (parts, body) = limited_req.into_parts();
|
||||||
let body = to_bytes(body)
|
let body =
|
||||||
.await
|
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
|
||||||
(parts, body)
|
(parts, body)
|
||||||
}
|
},
|
||||||
Err(original_req) => {
|
Err(original_req) => {
|
||||||
let (parts, body) = original_req.into_parts();
|
let (parts, body) = original_req.into_parts();
|
||||||
let body = to_bytes(body)
|
let body =
|
||||||
.await
|
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
|
||||||
(parts, body)
|
(parts, body)
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let metadata = T::METADATA;
|
let metadata = T::METADATA;
|
||||||
|
@ -66,11 +64,8 @@ where
|
||||||
Ok(params) => params,
|
Ok(params) => params,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!(%query, "Failed to deserialize query parameters: {}", e);
|
error!(%query, "Failed to deserialize query parameters: {}", e);
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"));
|
||||||
ErrorKind::Unknown,
|
},
|
||||||
"Failed to read query parameters",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let token = match &auth_header {
|
let token = match &auth_header {
|
||||||
|
@ -81,12 +76,12 @@ where
|
||||||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||||
|
|
||||||
let appservices = services().appservice.all().unwrap();
|
let appservices = services().appservice.all().unwrap();
|
||||||
let appservice_registration = appservices
|
let appservice_registration =
|
||||||
.iter()
|
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||||
.find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
|
||||||
|
|
||||||
let (sender_user, sender_device, sender_servername, from_appservice) =
|
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
|
||||||
if let Some((_id, registration)) = appservice_registration {
|
appservice_registration
|
||||||
|
{
|
||||||
match metadata.authentication {
|
match metadata.authentication {
|
||||||
AuthScheme::AccessToken => {
|
AuthScheme::AccessToken => {
|
||||||
let user_id = query_params.user_id.map_or_else(
|
let user_id = query_params.user_id.map_or_else(
|
||||||
|
@ -101,15 +96,12 @@ where
|
||||||
);
|
);
|
||||||
|
|
||||||
if !services().users.exists(&user_id).unwrap() {
|
if !services().users.exists(&user_id).unwrap() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist."));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"User does not exist.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Check if appservice is allowed to be that user
|
// TODO: Check if appservice is allowed to be that user
|
||||||
(Some(user_id), None, None, true)
|
(Some(user_id), None, None, true)
|
||||||
}
|
},
|
||||||
AuthScheme::ServerSignatures => (None, None, None, true),
|
AuthScheme::ServerSignatures => (None, None, None, true),
|
||||||
AuthScheme::None => (None, None, None, true),
|
AuthScheme::None => (None, None, None, true),
|
||||||
}
|
}
|
||||||
|
@ -118,92 +110,62 @@ where
|
||||||
AuthScheme::AccessToken => {
|
AuthScheme::AccessToken => {
|
||||||
let token = match token {
|
let token = match token {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
_ => {
|
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::MissingToken,
|
|
||||||
"Missing access token.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match services().users.find_from_token(token).unwrap() {
|
match services().users.find_from_token(token).unwrap() {
|
||||||
None => {
|
None => {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::UnknownToken { soft_logout: false },
|
ErrorKind::UnknownToken {
|
||||||
|
soft_logout: false,
|
||||||
|
},
|
||||||
"Unknown access token.",
|
"Unknown access token.",
|
||||||
))
|
))
|
||||||
|
},
|
||||||
|
Some((user_id, device_id)) => {
|
||||||
|
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||||
|
},
|
||||||
}
|
}
|
||||||
Some((user_id, device_id)) => (
|
},
|
||||||
Some(user_id),
|
|
||||||
Some(OwnedDeviceId::from(device_id)),
|
|
||||||
None,
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
AuthScheme::ServerSignatures => {
|
AuthScheme::ServerSignatures => {
|
||||||
let TypedHeader(Authorization(x_matrix)) = parts
|
let TypedHeader(Authorization(x_matrix)) =
|
||||||
.extract::<TypedHeader<Authorization<XMatrix>>>()
|
parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| {
|
||||||
.await
|
|
||||||
.map_err(|e| {
|
|
||||||
warn!("Missing or invalid Authorization header: {}", e);
|
warn!("Missing or invalid Authorization header: {}", e);
|
||||||
|
|
||||||
let msg = match e.reason() {
|
let msg = match e.reason() {
|
||||||
TypedHeaderRejectionReason::Missing => {
|
TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
|
||||||
"Missing Authorization header."
|
TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
|
||||||
}
|
|
||||||
TypedHeaderRejectionReason::Error(_) => {
|
|
||||||
"Invalid X-Matrix signatures."
|
|
||||||
}
|
|
||||||
_ => "Unknown header-related error",
|
_ => "Unknown header-related error",
|
||||||
};
|
};
|
||||||
|
|
||||||
Error::BadRequest(ErrorKind::Forbidden, msg)
|
Error::BadRequest(ErrorKind::Forbidden, msg)
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
let origin_signatures = BTreeMap::from_iter([(
|
let origin_signatures =
|
||||||
x_matrix.key.clone(),
|
BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]);
|
||||||
CanonicalJsonValue::String(x_matrix.sig),
|
|
||||||
)]);
|
|
||||||
|
|
||||||
let signatures = BTreeMap::from_iter([(
|
let signatures = BTreeMap::from_iter([(
|
||||||
x_matrix.origin.as_str().to_owned(),
|
x_matrix.origin.as_str().to_owned(),
|
||||||
CanonicalJsonValue::Object(origin_signatures),
|
CanonicalJsonValue::Object(origin_signatures),
|
||||||
)]);
|
)]);
|
||||||
|
|
||||||
let server_destination =
|
let server_destination = services().globals.server_name().as_str().to_owned();
|
||||||
services().globals.server_name().as_str().to_owned();
|
|
||||||
|
|
||||||
if let Some(destination) = x_matrix.destination.as_ref() {
|
if let Some(destination) = x_matrix.destination.as_ref() {
|
||||||
if destination != &server_destination {
|
if destination != &server_destination {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization."));
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"Invalid authorization.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut request_map = BTreeMap::from_iter([
|
let mut request_map = BTreeMap::from_iter([
|
||||||
(
|
("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())),
|
||||||
"method".to_owned(),
|
("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())),
|
||||||
CanonicalJsonValue::String(parts.method.to_string()),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"uri".to_owned(),
|
|
||||||
CanonicalJsonValue::String(parts.uri.to_string()),
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"origin".to_owned(),
|
"origin".to_owned(),
|
||||||
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
|
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
|
||||||
),
|
),
|
||||||
(
|
("destination".to_owned(), CanonicalJsonValue::String(server_destination)),
|
||||||
"destination".to_owned(),
|
("signatures".to_owned(), CanonicalJsonValue::Object(signatures)),
|
||||||
CanonicalJsonValue::String(server_destination),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"signatures".to_owned(),
|
|
||||||
CanonicalJsonValue::Object(signatures),
|
|
||||||
),
|
|
||||||
]);
|
]);
|
||||||
|
|
||||||
if let Some(json_body) = &json_body {
|
if let Some(json_body) = &json_body {
|
||||||
|
@ -213,25 +175,18 @@ where
|
||||||
let keys_result = services()
|
let keys_result = services()
|
||||||
.rooms
|
.rooms
|
||||||
.event_handler
|
.event_handler
|
||||||
.fetch_signing_keys_for_server(
|
.fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()])
|
||||||
&x_matrix.origin,
|
|
||||||
vec![x_matrix.key.clone()],
|
|
||||||
)
|
|
||||||
.await;
|
.await;
|
||||||
|
|
||||||
let keys = match keys_result {
|
let keys = match keys_result {
|
||||||
Ok(b) => b,
|
Ok(b) => b,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Failed to fetch signing keys: {}", e);
|
warn!("Failed to fetch signing keys: {}", e);
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys."));
|
||||||
ErrorKind::Forbidden,
|
},
|
||||||
"Failed to fetch signing keys.",
|
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let pub_key_map =
|
let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
||||||
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
|
||||||
|
|
||||||
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
|
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
|
||||||
Ok(()) => (None, None, Some(x_matrix.origin), false),
|
Ok(()) => (None, None, Some(x_matrix.origin), false),
|
||||||
|
@ -243,9 +198,8 @@ where
|
||||||
|
|
||||||
if parts.uri.to_string().contains('@') {
|
if parts.uri.to_string().contains('@') {
|
||||||
warn!(
|
warn!(
|
||||||
"Request uri contained '@' character. Make sure your \
|
"Request uri contained '@' character. Make sure your reverse proxy gives Conduit \
|
||||||
reverse proxy gives Conduit the raw uri (apache: use \
|
the raw uri (apache: use nocanon)"
|
||||||
nocanon)"
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -253,45 +207,35 @@ where
|
||||||
ErrorKind::Forbidden,
|
ErrorKind::Forbidden,
|
||||||
"Failed to verify X-Matrix signatures.",
|
"Failed to verify X-Matrix signatures.",
|
||||||
));
|
));
|
||||||
|
},
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
}
|
|
||||||
AuthScheme::None => match parts.uri.path() {
|
AuthScheme::None => match parts.uri.path() {
|
||||||
// allow_public_room_directory_without_auth
|
// allow_public_room_directory_without_auth
|
||||||
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
|
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
|
||||||
if !services()
|
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||||
.globals
|
|
||||||
.config
|
|
||||||
.allow_public_room_directory_without_auth
|
|
||||||
{
|
|
||||||
let token = match token {
|
let token = match token {
|
||||||
Some(token) => token,
|
Some(token) => token,
|
||||||
_ => {
|
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::MissingToken,
|
|
||||||
"Missing access token.",
|
|
||||||
))
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
match services().users.find_from_token(token).unwrap() {
|
match services().users.find_from_token(token).unwrap() {
|
||||||
None => {
|
None => {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::UnknownToken { soft_logout: false },
|
ErrorKind::UnknownToken {
|
||||||
|
soft_logout: false,
|
||||||
|
},
|
||||||
"Unknown access token.",
|
"Unknown access token.",
|
||||||
))
|
))
|
||||||
}
|
},
|
||||||
Some((user_id, device_id)) => (
|
Some((user_id, device_id)) => {
|
||||||
Some(user_id),
|
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||||
Some(OwnedDeviceId::from(device_id)),
|
},
|
||||||
None,
|
|
||||||
false,
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
(None, None, None, false)
|
(None, None, None, false)
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
_ => (None, None, None, false),
|
_ => (None, None, None, false),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
@ -302,8 +246,7 @@ where
|
||||||
|
|
||||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||||
UserId::parse_with_server_name("", services().globals.server_name())
|
UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid")
|
||||||
.expect("we know this is valid")
|
|
||||||
});
|
});
|
||||||
|
|
||||||
let uiaa_request = json_body
|
let uiaa_request = json_body
|
||||||
|
@ -367,9 +310,7 @@ impl Credentials for XMatrix {
|
||||||
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
|
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
|
||||||
);
|
);
|
||||||
|
|
||||||
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
|
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start();
|
||||||
.ok()?
|
|
||||||
.trim_start();
|
|
||||||
|
|
||||||
let mut origin = None;
|
let mut origin = None;
|
||||||
let mut destination = None;
|
let mut destination = None;
|
||||||
|
@ -381,10 +322,7 @@ impl Credentials for XMatrix {
|
||||||
|
|
||||||
// It's not at all clear why some fields are quoted and others not in the spec,
|
// It's not at all clear why some fields are quoted and others not in the spec,
|
||||||
// let's simply accept either form for every field.
|
// let's simply accept either form for every field.
|
||||||
let value = value
|
let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value);
|
||||||
.strip_prefix('"')
|
|
||||||
.and_then(|rest| rest.strip_suffix('"'))
|
|
||||||
.unwrap_or(value);
|
|
||||||
|
|
||||||
// FIXME: Catch multiple fields of the same name
|
// FIXME: Catch multiple fields of the same name
|
||||||
match name {
|
match name {
|
||||||
|
@ -392,10 +330,7 @@ impl Credentials for XMatrix {
|
||||||
"key" => key = Some(value.to_owned()),
|
"key" => key = Some(value.to_owned()),
|
||||||
"sig" => sig = Some(value.to_owned()),
|
"sig" => sig = Some(value.to_owned()),
|
||||||
"destination" => destination = Some(value.to_owned()),
|
"destination" => destination = Some(value.to_owned()),
|
||||||
_ => debug!(
|
_ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name),
|
||||||
"Unexpected field `{}` in X-Matrix Authorization header",
|
|
||||||
name
|
|
||||||
),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -407,9 +342,7 @@ impl Credentials for XMatrix {
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn encode(&self) -> http::HeaderValue {
|
fn encode(&self) -> http::HeaderValue { todo!() }
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
||||||
|
|
|
@ -1,10 +1,9 @@
|
||||||
use crate::Error;
|
|
||||||
use ruma::{
|
|
||||||
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
|
||||||
OwnedUserId,
|
|
||||||
};
|
|
||||||
use std::ops::Deref;
|
use std::ops::Deref;
|
||||||
|
|
||||||
|
use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
|
||||||
|
|
||||||
|
use crate::Error;
|
||||||
|
|
||||||
#[cfg(feature = "conduit_bin")]
|
#[cfg(feature = "conduit_bin")]
|
||||||
mod axum;
|
mod axum;
|
||||||
|
|
||||||
|
@ -22,22 +21,16 @@ pub struct Ruma<T> {
|
||||||
impl<T> Deref for Ruma<T> {
|
impl<T> Deref for Ruma<T> {
|
||||||
type Target = T;
|
type Target = T;
|
||||||
|
|
||||||
fn deref(&self) -> &Self::Target {
|
fn deref(&self) -> &Self::Target { &self.body }
|
||||||
&self.body
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[derive(Clone)]
|
#[derive(Clone)]
|
||||||
pub struct RumaResponse<T>(pub T);
|
pub struct RumaResponse<T>(pub T);
|
||||||
|
|
||||||
impl<T> From<T> for RumaResponse<T> {
|
impl<T> From<T> for RumaResponse<T> {
|
||||||
fn from(t: T) -> Self {
|
fn from(t: T) -> Self { Self(t) }
|
||||||
Self(t)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl From<Error> for RumaResponse<UiaaResponse> {
|
impl From<Error> for RumaResponse<UiaaResponse> {
|
||||||
fn from(t: Error) -> Self {
|
fn from(t: Error) -> Self { t.to_response() }
|
||||||
t.to_response()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -178,50 +178,51 @@ pub struct TlsConfig {
|
||||||
pub key: String,
|
pub key: String,
|
||||||
#[serde(default)]
|
#[serde(default)]
|
||||||
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
|
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
|
||||||
/// Only works / does something if the `axum_dual_protocol` feature flag was built
|
/// Only works / does something if the `axum_dual_protocol` feature flag was
|
||||||
|
/// built
|
||||||
pub dual_protocol: bool,
|
pub dual_protocol: bool,
|
||||||
}
|
}
|
||||||
|
|
||||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||||
|
|
||||||
impl Config {
|
impl Config {
|
||||||
/// Iterates over all the keys in the config file and warns if there is a deprecated key specified
|
/// Iterates over all the keys in the config file and warns if there is a
|
||||||
|
/// deprecated key specified
|
||||||
pub fn warn_deprecated(&self) {
|
pub fn warn_deprecated(&self) {
|
||||||
debug!("Checking for deprecated config keys");
|
debug!("Checking for deprecated config keys");
|
||||||
let mut was_deprecated = false;
|
let mut was_deprecated = false;
|
||||||
for key in self
|
for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) {
|
||||||
.catchall
|
|
||||||
.keys()
|
|
||||||
.filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key))
|
|
||||||
{
|
|
||||||
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
|
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
|
||||||
was_deprecated = true;
|
was_deprecated = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if was_deprecated {
|
if was_deprecated {
|
||||||
warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
/// iterates over all the catchall keys (unknown config options) and warns if there are any.
|
|
||||||
pub fn warn_unknown_key(&self) {
|
|
||||||
debug!("Checking for unknown config keys");
|
|
||||||
for key in self.catchall.keys().filter(
|
|
||||||
|key| "config".to_owned().ne(key.to_owned()), /* "config" is expected */
|
|
||||||
) {
|
|
||||||
warn!(
|
warn!(
|
||||||
"Config parameter \"{}\" is unknown to conduwuit, ignoring.",
|
"Read conduit documentation and check your configuration if any new configuration parameters should \
|
||||||
key
|
be adjusted"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the raw_config, exiting the process if both keys were detected.
|
/// iterates over all the catchall keys (unknown config options) and warns
|
||||||
|
/// if there are any.
|
||||||
|
pub fn warn_unknown_key(&self) {
|
||||||
|
debug!("Checking for unknown config keys");
|
||||||
|
for key in
|
||||||
|
self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */)
|
||||||
|
{
|
||||||
|
warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
/// Checks the presence of the `address` and `unix_socket_path` keys in the
|
||||||
|
/// raw_config, exiting the process if both keys were detected.
|
||||||
pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
|
pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
|
||||||
let check_address = raw_config.find_value("address");
|
let check_address = raw_config.find_value("address");
|
||||||
let check_unix_socket = raw_config.find_value("unix_socket_path");
|
let check_unix_socket = raw_config.find_value("unix_socket_path");
|
||||||
|
|
||||||
// are the check_address and check_unix_socket keys both Ok (specified) at the same time?
|
// are the check_address and check_unix_socket keys both Ok (specified) at the
|
||||||
|
// same time?
|
||||||
if check_address.is_ok() && check_unix_socket.is_ok() {
|
if check_address.is_ok() && check_unix_socket.is_ok() {
|
||||||
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
|
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
|
||||||
return true;
|
return true;
|
||||||
|
@ -238,28 +239,13 @@ impl fmt::Display for Config {
|
||||||
("Server name", self.server_name.host()),
|
("Server name", self.server_name.host()),
|
||||||
("Database backend", &self.database_backend),
|
("Database backend", &self.database_backend),
|
||||||
("Database path", &self.database_path),
|
("Database path", &self.database_path),
|
||||||
(
|
("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()),
|
||||||
"Database cache capacity (MB)",
|
("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()),
|
||||||
&self.db_cache_capacity_mb.to_string(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Cache capacity modifier",
|
|
||||||
&self.conduit_cache_capacity_modifier.to_string(),
|
|
||||||
),
|
|
||||||
("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
|
("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
|
||||||
(
|
("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()),
|
||||||
"Cleanup interval in seconds",
|
|
||||||
&self.cleanup_second_interval.to_string(),
|
|
||||||
),
|
|
||||||
("Maximum request size (bytes)", &self.max_request_size.to_string()),
|
("Maximum request size (bytes)", &self.max_request_size.to_string()),
|
||||||
(
|
("Maximum concurrent requests", &self.max_concurrent_requests.to_string()),
|
||||||
"Maximum concurrent requests",
|
("Allow registration", &self.allow_registration.to_string()),
|
||||||
&self.max_concurrent_requests.to_string(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"Allow registration",
|
|
||||||
&self.allow_registration.to_string(),
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"Registration token",
|
"Registration token",
|
||||||
match self.registration_token {
|
match self.registration_token {
|
||||||
|
@ -271,10 +257,7 @@ impl fmt::Display for Config {
|
||||||
"Allow guest registration (inherently false if allow registration is false)",
|
"Allow guest registration (inherently false if allow registration is false)",
|
||||||
&self.allow_guest_registration.to_string(),
|
&self.allow_guest_registration.to_string(),
|
||||||
),
|
),
|
||||||
(
|
("New user display name suffix", &self.new_user_displayname_suffix),
|
||||||
"New user display name suffix",
|
|
||||||
&self.new_user_displayname_suffix,
|
|
||||||
),
|
|
||||||
("Allow encryption", &self.allow_encryption.to_string()),
|
("Allow encryption", &self.allow_encryption.to_string()),
|
||||||
("Allow federation", &self.allow_federation.to_string()),
|
("Allow federation", &self.allow_federation.to_string()),
|
||||||
(
|
(
|
||||||
|
@ -293,10 +276,7 @@ impl fmt::Display for Config {
|
||||||
"Block non-admin room invites (local and remote, admins can still send and receive invites)",
|
"Block non-admin room invites (local and remote, admins can still send and receive invites)",
|
||||||
&self.block_non_admin_invites.to_string(),
|
&self.block_non_admin_invites.to_string(),
|
||||||
),
|
),
|
||||||
(
|
("Allow device name federation", &self.allow_device_name_federation.to_string()),
|
||||||
"Allow device name federation",
|
|
||||||
&self.allow_device_name_federation.to_string(),
|
|
||||||
),
|
|
||||||
("Notification push path", &self.notification_push_path),
|
("Notification push path", &self.notification_push_path),
|
||||||
("Allow room creation", &self.allow_room_creation.to_string()),
|
("Allow room creation", &self.allow_room_creation.to_string()),
|
||||||
(
|
(
|
||||||
|
@ -356,15 +336,9 @@ impl fmt::Display for Config {
|
||||||
}
|
}
|
||||||
&lst.join(", ")
|
&lst.join(", ")
|
||||||
}),
|
}),
|
||||||
(
|
("zstd Response Body Compression", &self.zstd_compression.to_string()),
|
||||||
"zstd Response Body Compression",
|
|
||||||
&self.zstd_compression.to_string(),
|
|
||||||
),
|
|
||||||
("RocksDB database log level", &self.rocksdb_log_level),
|
("RocksDB database log level", &self.rocksdb_log_level),
|
||||||
(
|
("RocksDB database log time-to-roll", &self.rocksdb_log_time_to_roll.to_string()),
|
||||||
"RocksDB database log time-to-roll",
|
|
||||||
&self.rocksdb_log_time_to_roll.to_string(),
|
|
||||||
),
|
|
||||||
(
|
(
|
||||||
"RocksDB database max log file size",
|
"RocksDB database max log file size",
|
||||||
&self.rocksdb_max_log_file_size.to_string(),
|
&self.rocksdb_max_log_file_size.to_string(),
|
||||||
|
@ -373,10 +347,7 @@ impl fmt::Display for Config {
|
||||||
"RocksDB database optimize for spinning disks",
|
"RocksDB database optimize for spinning disks",
|
||||||
&self.rocksdb_optimize_for_spinning_disks.to_string(),
|
&self.rocksdb_optimize_for_spinning_disks.to_string(),
|
||||||
),
|
),
|
||||||
(
|
("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()),
|
||||||
"RocksDB Parallelism Threads",
|
|
||||||
&self.rocksdb_parallelism_threads.to_string(),
|
|
||||||
),
|
|
||||||
("Prevent Media Downloads From", {
|
("Prevent Media Downloads From", {
|
||||||
let mut lst = vec![];
|
let mut lst = vec![];
|
||||||
for domain in &self.prevent_media_downloads_from {
|
for domain in &self.prevent_media_downloads_from {
|
||||||
|
@ -410,14 +381,8 @@ impl fmt::Display for Config {
|
||||||
"URL preview URL contains allowlist",
|
"URL preview URL contains allowlist",
|
||||||
&self.url_preview_url_contains_allowlist.join(", "),
|
&self.url_preview_url_contains_allowlist.join(", "),
|
||||||
),
|
),
|
||||||
(
|
("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()),
|
||||||
"URL preview maximum spider size",
|
("URL preview check root domain", &self.url_preview_check_root_domain.to_string()),
|
||||||
&self.url_preview_max_spider_size.to_string(),
|
|
||||||
),
|
|
||||||
(
|
|
||||||
"URL preview check root domain",
|
|
||||||
&self.url_preview_check_root_domain.to_string(),
|
|
||||||
),
|
|
||||||
];
|
];
|
||||||
|
|
||||||
let mut msg: String = "Active config values:\n\n".to_owned();
|
let mut msg: String = "Active config values:\n\n".to_owned();
|
||||||
|
@ -430,13 +395,9 @@ impl fmt::Display for Config {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn true_fn() -> bool {
|
fn true_fn() -> bool { true }
|
||||||
true
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_address() -> IpAddr {
|
fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() }
|
||||||
Ipv4Addr::LOCALHOST.into()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_port() -> ListeningPort {
|
fn default_port() -> ListeningPort {
|
||||||
ListeningPort {
|
ListeningPort {
|
||||||
|
@ -444,25 +405,15 @@ fn default_port() -> ListeningPort {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_unix_socket_perms() -> u32 {
|
fn default_unix_socket_perms() -> u32 { 660 }
|
||||||
660
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_database_backend() -> String {
|
fn default_database_backend() -> String { "rocksdb".to_owned() }
|
||||||
"rocksdb".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_db_cache_capacity_mb() -> f64 {
|
fn default_db_cache_capacity_mb() -> f64 { 300.0 }
|
||||||
300.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_conduit_cache_capacity_modifier() -> f64 {
|
fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 }
|
||||||
1.0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_pdu_cache_capacity() -> u32 {
|
fn default_pdu_cache_capacity() -> u32 { 150_000 }
|
||||||
150_000
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_cleanup_second_interval() -> u32 {
|
fn default_cleanup_second_interval() -> u32 {
|
||||||
60 // every minute
|
60 // every minute
|
||||||
|
@ -472,54 +423,30 @@ fn default_max_request_size() -> u32 {
|
||||||
20 * 1024 * 1024 // Default to 20 MB
|
20 * 1024 * 1024 // Default to 20 MB
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_max_concurrent_requests() -> u16 {
|
fn default_max_concurrent_requests() -> u16 { 500 }
|
||||||
500
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_max_fetch_prev_events() -> u16 {
|
fn default_max_fetch_prev_events() -> u16 { 100_u16 }
|
||||||
100_u16
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_trusted_servers() -> Vec<OwnedServerName> {
|
fn default_trusted_servers() -> Vec<OwnedServerName> { vec![OwnedServerName::try_from("matrix.org").unwrap()] }
|
||||||
vec![OwnedServerName::try_from("matrix.org").unwrap()]
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_log() -> String {
|
fn default_log() -> String { "warn,state_res=warn".to_owned() }
|
||||||
"warn,state_res=warn".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_notification_push_path() -> String {
|
fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() }
|
||||||
"/_matrix/push/v1/notify".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_turn_ttl() -> u64 {
|
fn default_turn_ttl() -> u64 { 60 * 60 * 24 }
|
||||||
60 * 60 * 24
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_presence_idle_timeout_s() -> u64 {
|
fn default_presence_idle_timeout_s() -> u64 { 5 * 60 }
|
||||||
5 * 60
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_presence_offline_timeout_s() -> u64 {
|
fn default_presence_offline_timeout_s() -> u64 { 30 * 60 }
|
||||||
30 * 60
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_rocksdb_log_level() -> String {
|
fn default_rocksdb_log_level() -> String { "warn".to_owned() }
|
||||||
"warn".to_owned()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_rocksdb_log_time_to_roll() -> usize {
|
fn default_rocksdb_log_time_to_roll() -> usize { 0 }
|
||||||
0
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_rocksdb_parallelism_threads() -> usize {
|
fn default_rocksdb_parallelism_threads() -> usize { num_cpus::get_physical() / 2 }
|
||||||
num_cpus::get_physical() / 2
|
|
||||||
}
|
|
||||||
|
|
||||||
// I know, it's a great name
|
// I know, it's a great name
|
||||||
pub(crate) fn default_default_room_version() -> RoomVersionId {
|
pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
|
||||||
RoomVersionId::V10
|
|
||||||
}
|
|
||||||
|
|
||||||
fn default_rocksdb_max_log_file_size() -> usize {
|
fn default_rocksdb_max_log_file_size() -> usize {
|
||||||
// 4 megabytes
|
// 4 megabytes
|
||||||
|
@ -554,6 +481,4 @@ fn default_url_preview_max_spider_size() -> usize {
|
||||||
1_000_000 // 1MB
|
1_000_000 // 1MB
|
||||||
}
|
}
|
||||||
|
|
||||||
fn default_new_user_displayname_suffix() -> String {
|
fn default_new_user_displayname_suffix() -> String { "🏳️⚧️".to_owned() }
|
||||||
"🏳️⚧️".to_owned()
|
|
||||||
}
|
|
||||||
|
|
|
@ -24,9 +24,10 @@ use crate::Result;
|
||||||
/// ## Include vs. Exclude
|
/// ## Include vs. Exclude
|
||||||
/// If include is an empty list, it is assumed to be `["*"]`.
|
/// If include is an empty list, it is assumed to be `["*"]`.
|
||||||
///
|
///
|
||||||
/// If a domain matches both the exclude and include list, the proxy will only be used if it was
|
/// If a domain matches both the exclude and include list, the proxy will only
|
||||||
/// included because of a more specific rule than it was excluded. In the above example, the proxy
|
/// be used if it was included because of a more specific rule than it was
|
||||||
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
/// excluded. In the above example, the proxy would be used for
|
||||||
|
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
||||||
#[derive(Clone, Default, Debug, Deserialize)]
|
#[derive(Clone, Default, Debug, Deserialize)]
|
||||||
#[serde(rename_all = "snake_case")]
|
#[serde(rename_all = "snake_case")]
|
||||||
pub enum ProxyConfig {
|
pub enum ProxyConfig {
|
||||||
|
@ -42,9 +43,12 @@ impl ProxyConfig {
|
||||||
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
|
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
|
||||||
Ok(match self.clone() {
|
Ok(match self.clone() {
|
||||||
ProxyConfig::None => None,
|
ProxyConfig::None => None,
|
||||||
ProxyConfig::Global { url } => Some(Proxy::all(url)?),
|
ProxyConfig::Global {
|
||||||
|
url,
|
||||||
|
} => Some(Proxy::all(url)?),
|
||||||
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
|
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
|
||||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
|
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching
|
||||||
|
// proxy
|
||||||
})),
|
})),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -85,7 +89,8 @@ impl PartialProxyConfig {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
match (included_because, excluded_because) {
|
match (included_because, excluded_because) {
|
||||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
|
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */
|
||||||
|
// than excluded
|
||||||
(Some(_), None) => Some(&self.url),
|
(Some(_), None) => Some(&self.url),
|
||||||
_ => None,
|
_ => None,
|
||||||
}
|
}
|
||||||
|
@ -107,20 +112,20 @@ impl WildCardedDomain {
|
||||||
WildCardedDomain::Exact(d) => domain == d,
|
WildCardedDomain::Exact(d) => domain == d,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn more_specific_than(&self, other: &Self) -> bool {
|
fn more_specific_than(&self, other: &Self) -> bool {
|
||||||
match (self, other) {
|
match (self, other) {
|
||||||
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||||
(_, WildCardedDomain::WildCard) => true,
|
(_, WildCardedDomain::WildCard) => true,
|
||||||
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
|
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b),
|
||||||
a != b && a.ends_with(b)
|
|
||||||
}
|
|
||||||
_ => false,
|
_ => false,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
impl std::str::FromStr for WildCardedDomain {
|
impl std::str::FromStr for WildCardedDomain {
|
||||||
type Err = std::convert::Infallible;
|
type Err = std::convert::Infallible;
|
||||||
|
|
||||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||||
// maybe do some domain validation?
|
// maybe do some domain validation?
|
||||||
Ok(if s.starts_with("*.") {
|
Ok(if s.starts_with("*.") {
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
|
use std::{future::Future, pin::Pin, sync::Arc};
|
||||||
|
|
||||||
use super::Config;
|
use super::Config;
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
|
||||||
use std::{future::Future, pin::Pin, sync::Arc};
|
|
||||||
|
|
||||||
#[cfg(feature = "sqlite")]
|
#[cfg(feature = "sqlite")]
|
||||||
pub mod sqlite;
|
pub mod sqlite;
|
||||||
|
|
||||||
|
@ -18,9 +18,7 @@ pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
|
||||||
Self: Sized;
|
Self: Sized;
|
||||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||||
fn flush(&self) -> Result<()>;
|
fn flush(&self) -> Result<()>;
|
||||||
fn cleanup(&self) -> Result<()> {
|
fn cleanup(&self) -> Result<()> { Ok(()) }
|
||||||
Ok(())
|
|
||||||
}
|
|
||||||
fn memory_usage(&self) -> Result<String> {
|
fn memory_usage(&self) -> Result<String> {
|
||||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||||
}
|
}
|
||||||
|
@ -39,19 +37,12 @@ pub(crate) trait KvTree: Send + Sync {
|
||||||
|
|
||||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||||
|
|
||||||
fn iter_from<'a>(
|
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||||
&'a self,
|
|
||||||
from: &[u8],
|
|
||||||
backwards: bool,
|
|
||||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
|
||||||
|
|
||||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||||
|
|
||||||
fn scan_prefix<'a>(
|
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||||
&'a self,
|
|
||||||
prefix: Vec<u8>,
|
|
||||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
|
||||||
|
|
||||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||||
|
|
||||||
|
|
|
@ -7,9 +7,8 @@ use std::{
|
||||||
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
|
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
|
||||||
use tracing::{debug, info};
|
use tracing::{debug, info};
|
||||||
|
|
||||||
use crate::{utils, Result};
|
|
||||||
|
|
||||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||||
|
use crate::{utils, Result};
|
||||||
|
|
||||||
pub(crate) struct Engine {
|
pub(crate) struct Engine {
|
||||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||||
|
@ -62,7 +61,9 @@ fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Optio
|
||||||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important
|
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for
|
||||||
|
// spinning hard drives. these are not really
|
||||||
|
// important
|
||||||
} else {
|
} else {
|
||||||
db_opts.set_skip_stats_update_on_db_open(false);
|
db_opts.set_skip_stats_update_on_db_open(false);
|
||||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||||
|
@ -75,9 +76,7 @@ fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Optio
|
||||||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||||
db_opts.create_if_missing(true);
|
db_opts.create_if_missing(true);
|
||||||
db_opts.increase_parallelism(
|
db_opts.increase_parallelism(
|
||||||
threads
|
threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||||
.try_into()
|
|
||||||
.expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
|
||||||
);
|
);
|
||||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||||
|
@ -109,10 +108,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
let db_opts = db_options(&rocksdb_cache, config);
|
let db_opts = db_options(&rocksdb_cache, config);
|
||||||
|
|
||||||
debug!("Listing column families in database");
|
debug!("Listing column families in database");
|
||||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(
|
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(&db_opts, &config.database_path)
|
||||||
&db_opts,
|
|
||||||
&config.database_path,
|
|
||||||
)
|
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
debug!("Opening column family descriptors in database");
|
debug!("Opening column family descriptors in database");
|
||||||
|
@ -120,9 +116,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||||
&db_opts,
|
&db_opts,
|
||||||
&config.database_path,
|
&config.database_path,
|
||||||
cfs.iter().map(|name| {
|
cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))),
|
||||||
rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))
|
|
||||||
}),
|
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
Ok(Arc::new(Engine {
|
Ok(Arc::new(Engine {
|
||||||
|
@ -137,9 +131,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
if !self.old_cfs.contains(&name.to_owned()) {
|
if !self.old_cfs.contains(&name.to_owned()) {
|
||||||
// Create if it didn't exist
|
// Create if it didn't exist
|
||||||
debug!("Creating new column family in database: {}", name);
|
debug!("Creating new column family in database: {}", name);
|
||||||
let _ = self
|
let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config));
|
||||||
.rocks
|
|
||||||
.create_cf(name, &db_options(&self.cache, &self.config));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(Arc::new(RocksDbEngineTree {
|
Ok(Arc::new(RocksDbEngineTree {
|
||||||
|
@ -156,15 +148,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn memory_usage(&self) -> Result<String> {
|
fn memory_usage(&self) -> Result<String> {
|
||||||
let stats =
|
let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||||
rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
|
||||||
Ok(format!(
|
Ok(format!(
|
||||||
"Approximate memory usage of all the mem-tables: {:.3} MB\n\
|
"Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \
|
||||||
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\
|
mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \
|
||||||
Approximate memory usage of all the table readers: {:.3} MB\n\
|
usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n",
|
||||||
Approximate memory usage by cache: {:.3} MB\n\
|
|
||||||
Approximate memory usage by cache pinned: {:.3} MB\n\
|
|
||||||
",
|
|
||||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||||
|
@ -179,15 +167,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl RocksDbEngineTree<'_> {
|
impl RocksDbEngineTree<'_> {
|
||||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> {
|
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { self.db.rocks.cf_handle(self.name).unwrap() }
|
||||||
self.db.rocks.cf_handle(self.name).unwrap()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KvTree for RocksDbEngineTree<'_> {
|
impl KvTree for RocksDbEngineTree<'_> {
|
||||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) }
|
||||||
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||||
let lock = self.write_lock.read().unwrap();
|
let lock = self.write_lock.read().unwrap();
|
||||||
|
@ -207,9 +191,7 @@ impl KvTree for RocksDbEngineTree<'_> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) }
|
||||||
Ok(self.db.rocks.delete_cf(&self.cf(), key)?)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||||
Box::new(
|
Box::new(
|
||||||
|
@ -221,11 +203,7 @@ impl KvTree for RocksDbEngineTree<'_> {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter_from<'a>(
|
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||||
&'a self,
|
|
||||||
from: &[u8],
|
|
||||||
backwards: bool,
|
|
||||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
|
||||||
Box::new(
|
Box::new(
|
||||||
self.db
|
self.db
|
||||||
.rocks
|
.rocks
|
||||||
|
@ -270,17 +248,11 @@ impl KvTree for RocksDbEngineTree<'_> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scan_prefix<'a>(
|
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||||
&'a self,
|
|
||||||
prefix: Vec<u8>,
|
|
||||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
|
||||||
Box::new(
|
Box::new(
|
||||||
self.db
|
self.db
|
||||||
.rocks
|
.rocks
|
||||||
.iterator_cf(
|
.iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward))
|
||||||
&self.cf(),
|
|
||||||
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward),
|
|
||||||
)
|
|
||||||
.map(std::result::Result::unwrap)
|
.map(std::result::Result::unwrap)
|
||||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||||
|
|
|
@ -1,7 +1,3 @@
|
||||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
|
||||||
use crate::{database::Config, Result};
|
|
||||||
use parking_lot::{Mutex, MutexGuard};
|
|
||||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
|
||||||
use std::{
|
use std::{
|
||||||
cell::RefCell,
|
cell::RefCell,
|
||||||
future::Future,
|
future::Future,
|
||||||
|
@ -9,9 +5,15 @@ use std::{
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::Arc,
|
sync::Arc,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use parking_lot::{Mutex, MutexGuard};
|
||||||
|
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||||
use thread_local::ThreadLocal;
|
use thread_local::ThreadLocal;
|
||||||
use tracing::debug;
|
use tracing::debug;
|
||||||
|
|
||||||
|
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||||
|
use crate::{database::Config, Result};
|
||||||
|
|
||||||
thread_local! {
|
thread_local! {
|
||||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||||
|
@ -25,9 +27,7 @@ struct PreparedStatementIterator<'a> {
|
||||||
impl Iterator for PreparedStatementIterator<'_> {
|
impl Iterator for PreparedStatementIterator<'_> {
|
||||||
type Item = TupleOfBytes;
|
type Item = TupleOfBytes;
|
||||||
|
|
||||||
fn next(&mut self) -> Option<Self::Item> {
|
fn next(&mut self) -> Option<Self::Item> { self.iterator.next() }
|
||||||
self.iterator.next()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct NonAliasingBox<T>(*mut T);
|
struct NonAliasingBox<T>(*mut T);
|
||||||
|
@ -61,23 +61,18 @@ impl Engine {
|
||||||
Ok(conn)
|
Ok(conn)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn write_lock(&self) -> MutexGuard<'_, Connection> {
|
fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
|
||||||
self.writer.lock()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn read_lock(&self) -> &Connection {
|
fn read_lock(&self) -> &Connection {
|
||||||
self.read_conn_tls
|
self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn read_lock_iterator(&self) -> &Connection {
|
fn read_lock_iterator(&self) -> &Connection {
|
||||||
self.read_iterator_conn_tls
|
self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||||
self.write_lock()
|
self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||||
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -88,11 +83,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
|
|
||||||
// calculates cache-size per permanent connection
|
// calculates cache-size per permanent connection
|
||||||
// 1. convert MB to KiB
|
// 1. convert MB to KiB
|
||||||
// 2. divide by permanent connections + permanent iter connections + write connection
|
// 2. divide by permanent connections + permanent iter connections + write
|
||||||
|
// connection
|
||||||
// 3. round down to nearest integer
|
// 3. round down to nearest integer
|
||||||
let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0)
|
let cache_size_per_thread: u32 =
|
||||||
/ ((num_cpus::get().max(1) * 2) + 1) as f64)
|
((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32;
|
||||||
as u32;
|
|
||||||
|
|
||||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||||
|
|
||||||
|
@ -108,7 +103,10 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||||
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?;
|
self.write_lock().execute(
|
||||||
|
&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"),
|
||||||
|
[],
|
||||||
|
)?;
|
||||||
|
|
||||||
Ok(Arc::new(SqliteTable {
|
Ok(Arc::new(SqliteTable {
|
||||||
engine: Arc::clone(self),
|
engine: Arc::clone(self),
|
||||||
|
@ -122,9 +120,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cleanup(&self) -> Result<()> {
|
fn cleanup(&self) -> Result<()> { self.flush_wal() }
|
||||||
self.flush_wal()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub struct SqliteTable {
|
pub struct SqliteTable {
|
||||||
|
@ -145,27 +141,15 @@ impl SqliteTable {
|
||||||
|
|
||||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||||
guard.execute(
|
guard.execute(
|
||||||
format!(
|
format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(),
|
||||||
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
|
|
||||||
self.name
|
|
||||||
)
|
|
||||||
.as_str(),
|
|
||||||
[key, value],
|
[key, value],
|
||||||
)?;
|
)?;
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn iter_with_guard<'a>(
|
pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||||
&'a self,
|
|
||||||
guard: &'a Connection,
|
|
||||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
|
||||||
let statement = Box::leak(Box::new(
|
let statement = Box::leak(Box::new(
|
||||||
guard
|
guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(),
|
||||||
.prepare(&format!(
|
|
||||||
"SELECT key, value FROM {} ORDER BY key ASC",
|
|
||||||
&self.name
|
|
||||||
))
|
|
||||||
.unwrap(),
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let statement_ref = NonAliasingBox(statement);
|
let statement_ref = NonAliasingBox(statement);
|
||||||
|
@ -173,10 +157,7 @@ impl SqliteTable {
|
||||||
//let name = self.name.clone();
|
//let name = self.name.clone();
|
||||||
|
|
||||||
let iterator = Box::new(
|
let iterator = Box::new(
|
||||||
statement
|
statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()),
|
||||||
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
|
||||||
.unwrap()
|
|
||||||
.map(move |r| r.unwrap()),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
Box::new(PreparedStatementIterator {
|
Box::new(PreparedStatementIterator {
|
||||||
|
@ -187,9 +168,7 @@ impl SqliteTable {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl KvTree for SqliteTable {
|
impl KvTree for SqliteTable {
|
||||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) }
|
||||||
self.get_with_guard(self.engine.read_lock(), key)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||||
let guard = self.engine.write_lock();
|
let guard = self.engine.write_lock();
|
||||||
|
@ -219,8 +198,7 @@ impl KvTree for SqliteTable {
|
||||||
guard.execute("BEGIN", [])?;
|
guard.execute("BEGIN", [])?;
|
||||||
for key in iter {
|
for key in iter {
|
||||||
let old = self.get_with_guard(&guard, &key)?;
|
let old = self.get_with_guard(&guard, &key)?;
|
||||||
let new = crate::utils::increment(old.as_deref())
|
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||||
.expect("utils::increment always returns Some");
|
|
||||||
self.insert_with_guard(&guard, &key, &new)?;
|
self.insert_with_guard(&guard, &key, &new)?;
|
||||||
}
|
}
|
||||||
guard.execute("COMMIT", [])?;
|
guard.execute("COMMIT", [])?;
|
||||||
|
@ -233,10 +211,7 @@ impl KvTree for SqliteTable {
|
||||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||||
let guard = self.engine.write_lock();
|
let guard = self.engine.write_lock();
|
||||||
|
|
||||||
guard.execute(
|
guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?;
|
||||||
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
|
|
||||||
[key],
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -247,11 +222,7 @@ impl KvTree for SqliteTable {
|
||||||
self.iter_with_guard(guard)
|
self.iter_with_guard(guard)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter_from<'a>(
|
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||||
&'a self,
|
|
||||||
from: &[u8],
|
|
||||||
backwards: bool,
|
|
||||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
|
||||||
let guard = self.engine.read_lock_iterator();
|
let guard = self.engine.read_lock_iterator();
|
||||||
let from = from.to_vec(); // TODO change interface?
|
let from = from.to_vec(); // TODO change interface?
|
||||||
|
|
||||||
|
@ -310,8 +281,7 @@ impl KvTree for SqliteTable {
|
||||||
|
|
||||||
let old = self.get_with_guard(&guard, key)?;
|
let old = self.get_with_guard(&guard, key)?;
|
||||||
|
|
||||||
let new =
|
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||||
crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
|
||||||
|
|
||||||
self.insert_with_guard(&guard, key, &new)?;
|
self.insert_with_guard(&guard, key, &new)?;
|
||||||
|
|
||||||
|
@ -319,10 +289,7 @@ impl KvTree for SqliteTable {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||||
Box::new(
|
Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix)))
|
||||||
self.iter_from(&prefix, false)
|
|
||||||
.take_while(move |(key, _)| key.starts_with(&prefix)),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||||
|
@ -331,9 +298,7 @@ impl KvTree for SqliteTable {
|
||||||
|
|
||||||
fn clear(&self) -> Result<()> {
|
fn clear(&self) -> Result<()> {
|
||||||
debug!("clear: running");
|
debug!("clear: running");
|
||||||
self.engine
|
self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||||
.write_lock()
|
|
||||||
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
|
||||||
debug!("clear: ran");
|
debug!("clear: ran");
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ use std::{
|
||||||
pin::Pin,
|
pin::Pin,
|
||||||
sync::RwLock,
|
sync::RwLock,
|
||||||
};
|
};
|
||||||
|
|
||||||
use tokio::sync::watch;
|
use tokio::sync::watch;
|
||||||
|
|
||||||
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
|
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
|
||||||
|
@ -14,17 +15,14 @@ pub(super) struct Watchers {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Watchers {
|
impl Watchers {
|
||||||
pub(super) fn watch<'a>(
|
pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||||
&'a self,
|
|
||||||
prefix: &[u8],
|
|
||||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
|
||||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||||
hash_map::Entry::Vacant(v) => {
|
hash_map::Entry::Vacant(v) => {
|
||||||
let (tx, rx) = tokio::sync::watch::channel(());
|
let (tx, rx) = tokio::sync::watch::channel(());
|
||||||
v.insert((tx, rx.clone()));
|
v.insert((tx, rx.clone()));
|
||||||
rx
|
rx
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
Box::pin(async move {
|
Box::pin(async move {
|
||||||
|
@ -32,6 +30,7 @@ impl Watchers {
|
||||||
rx.changed().await.unwrap();
|
rx.changed().await.unwrap();
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
pub(super) fn wake(&self, key: &[u8]) {
|
pub(super) fn wake(&self, key: &[u8]) {
|
||||||
let watchers = self.watchers.read().unwrap();
|
let watchers = self.watchers.read().unwrap();
|
||||||
let mut triggered = Vec::new();
|
let mut triggered = Vec::new();
|
||||||
|
|
|
@ -11,27 +11,21 @@ use tracing::warn;
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||||
|
|
||||||
impl service::account_data::Data for KeyValueDatabase {
|
impl service::account_data::Data for KeyValueDatabase {
|
||||||
/// Places one event in the account data of the user and removes the previous entry.
|
/// Places one event in the account data of the user and removes the
|
||||||
|
/// previous entry.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||||
fn update(
|
fn update(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
event_type: RoomAccountDataEventType,
|
|
||||||
data: &serde_json::Value,
|
data: &serde_json::Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut prefix = room_id
|
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||||
.map(std::string::ToString::to_string)
|
prefix.push(0xFF);
|
||||||
.unwrap_or_default()
|
|
||||||
.as_bytes()
|
|
||||||
.to_vec();
|
|
||||||
prefix.push(0xff);
|
|
||||||
prefix.extend_from_slice(user_id.as_bytes());
|
prefix.extend_from_slice(user_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut roomuserdataid = prefix.clone();
|
let mut roomuserdataid = prefix.clone();
|
||||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||||
roomuserdataid.push(0xff);
|
roomuserdataid.push(0xFF);
|
||||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||||
|
|
||||||
let mut key = prefix;
|
let mut key = prefix;
|
||||||
|
@ -51,8 +45,7 @@ impl service::account_data::Data for KeyValueDatabase {
|
||||||
|
|
||||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||||
|
|
||||||
self.roomusertype_roomuserdataid
|
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||||
.insert(&key, &roomuserdataid)?;
|
|
||||||
|
|
||||||
// Remove old entry
|
// Remove old entry
|
||||||
if let Some(prev) = prev {
|
if let Some(prev) = prev {
|
||||||
|
@ -65,54 +58,33 @@ impl service::account_data::Data for KeyValueDatabase {
|
||||||
/// Searches the account data for a specific kind.
|
/// Searches the account data for a specific kind.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||||
fn get(
|
fn get(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
kind: RoomAccountDataEventType,
|
|
||||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||||
let mut key = room_id
|
let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||||
.map(std::string::ToString::to_string)
|
key.push(0xFF);
|
||||||
.unwrap_or_default()
|
|
||||||
.as_bytes()
|
|
||||||
.to_vec();
|
|
||||||
key.push(0xff);
|
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(kind.to_string().as_bytes());
|
key.extend_from_slice(kind.to_string().as_bytes());
|
||||||
|
|
||||||
self.roomusertype_roomuserdataid
|
self.roomusertype_roomuserdataid
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.and_then(|roomuserdataid| {
|
.and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose())
|
||||||
self.roomuserdataid_accountdata
|
|
||||||
.get(&roomuserdataid)
|
|
||||||
.transpose()
|
|
||||||
})
|
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.map(|data| {
|
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
|
||||||
serde_json::from_slice(&data)
|
|
||||||
.map_err(|_| Error::bad_database("could not deserialize"))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns all changes to the account data that happened after `since`.
|
/// Returns all changes to the account data that happened after `since`.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||||
fn changes_since(
|
fn changes_since(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
since: u64,
|
|
||||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||||
let mut userdata = HashMap::new();
|
let mut userdata = HashMap::new();
|
||||||
|
|
||||||
let mut prefix = room_id
|
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||||
.map(std::string::ToString::to_string)
|
prefix.push(0xFF);
|
||||||
.unwrap_or_default()
|
|
||||||
.as_bytes()
|
|
||||||
.to_vec();
|
|
||||||
prefix.push(0xff);
|
|
||||||
prefix.extend_from_slice(user_id.as_bytes());
|
prefix.extend_from_slice(user_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
// Skip the data that's exactly at since, because we sent that last time
|
// Skip the data that's exactly at since, because we sent that last time
|
||||||
let mut first_possible = prefix.clone();
|
let mut first_possible = prefix.clone();
|
||||||
|
@ -125,20 +97,20 @@ impl service::account_data::Data for KeyValueDatabase {
|
||||||
.map(|(k, v)| {
|
.map(|(k, v)| {
|
||||||
Ok::<_, Error>((
|
Ok::<_, Error>((
|
||||||
RoomAccountDataEventType::from(
|
RoomAccountDataEventType::from(
|
||||||
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|
utils::string_from_bytes(
|
||||||
|| Error::bad_database("RoomUserData ID in db is invalid."),
|
k.rsplit(|&b| b == 0xFF)
|
||||||
)?)
|
.next()
|
||||||
|
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||||
|
)
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||||
})?,
|
})?,
|
||||||
),
|
),
|
||||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
|
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
|
||||||
Error::bad_database("Database contains invalid account data.")
|
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
|
||||||
})?,
|
|
||||||
))
|
))
|
||||||
})
|
}) {
|
||||||
{
|
|
||||||
let (kind, data) = r?;
|
let (kind, data) = r?;
|
||||||
userdata.insert(kind, data);
|
userdata.insert(kind, data);
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,14 +6,8 @@ impl service::appservice::Data for KeyValueDatabase {
|
||||||
/// Registers an appservice and returns the ID to the caller
|
/// Registers an appservice and returns the ID to the caller
|
||||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||||
let id = yaml.id.as_str();
|
let id = yaml.id.as_str();
|
||||||
self.id_appserviceregistrations.insert(
|
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||||
id.as_bytes(),
|
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
|
||||||
)?;
|
|
||||||
self.cached_registrations
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.insert(id.to_owned(), yaml.clone());
|
|
||||||
|
|
||||||
Ok(id.to_owned())
|
Ok(id.to_owned())
|
||||||
}
|
}
|
||||||
|
@ -24,29 +18,19 @@ impl service::appservice::Data for KeyValueDatabase {
|
||||||
///
|
///
|
||||||
/// * `service_name` - the name you send to register the service previously
|
/// * `service_name` - the name you send to register the service previously
|
||||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||||
self.id_appserviceregistrations
|
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||||
.remove(service_name.as_bytes())?;
|
self.cached_registrations.write().unwrap().remove(service_name);
|
||||||
self.cached_registrations
|
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.remove(service_name);
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||||
self.cached_registrations
|
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.get(id)
|
|
||||||
.map_or_else(
|
|
||||||
|| {
|
|| {
|
||||||
self.id_appserviceregistrations
|
self.id_appserviceregistrations
|
||||||
.get(id.as_bytes())?
|
.get(id.as_bytes())?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||||
Error::bad_database(
|
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||||
"Invalid registration bytes in id_appserviceregistrations.",
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
|
@ -56,13 +40,10 @@ impl service::appservice::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||||
|(id, _)| {
|
utils::string_from_bytes(&id)
|
||||||
utils::string_from_bytes(&id).map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
})))
|
||||||
})
|
|
||||||
},
|
|
||||||
)))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||||
|
@ -71,8 +52,7 @@ impl service::appservice::Data for KeyValueDatabase {
|
||||||
.map(move |id| {
|
.map(move |id| {
|
||||||
Ok((
|
Ok((
|
||||||
id.clone(),
|
id.clone(),
|
||||||
self.get_registration(&id)?
|
self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"),
|
||||||
.expect("iter_ids only returns appservices that exist"),
|
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
|
|
@ -23,24 +23,19 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn current_count(&self) -> Result<u64> {
|
fn current_count(&self) -> Result<u64> {
|
||||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||||
utils::u64_from_bytes(&bytes)
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||||
self.global
|
self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| {
|
||||||
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
|
utils::u64_from_bytes(&bytes)
|
||||||
.map_or(Ok(0_u64), |bytes| {
|
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
|
||||||
Error::bad_database("last check for updates count has invalid bytes.")
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||||
self.global
|
self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||||
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -48,11 +43,11 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||||
let userid_bytes = user_id.as_bytes().to_vec();
|
let userid_bytes = user_id.as_bytes().to_vec();
|
||||||
let mut userid_prefix = userid_bytes.clone();
|
let mut userid_prefix = userid_bytes.clone();
|
||||||
userid_prefix.push(0xff);
|
userid_prefix.push(0xFF);
|
||||||
|
|
||||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||||
userdeviceid_prefix.push(0xff);
|
userdeviceid_prefix.push(0xFF);
|
||||||
|
|
||||||
let mut futures = FuturesUnordered::new();
|
let mut futures = FuturesUnordered::new();
|
||||||
|
|
||||||
|
@ -63,19 +58,11 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||||
futures.push(
|
futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix));
|
||||||
self.userroomid_notificationcount
|
|
||||||
.watch_prefix(&userid_prefix),
|
|
||||||
);
|
|
||||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||||
|
|
||||||
// Events for rooms we are in
|
// Events for rooms we are in
|
||||||
for room_id in services()
|
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.rooms_joined(user_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
let short_roomid = services()
|
let short_roomid = services()
|
||||||
.rooms
|
.rooms
|
||||||
.short
|
.short
|
||||||
|
@ -88,7 +75,7 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
|
|
||||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||||
let mut roomid_prefix = roomid_bytes.clone();
|
let mut roomid_prefix = roomid_bytes.clone();
|
||||||
roomid_prefix.push(0xff);
|
roomid_prefix.push(0xFF);
|
||||||
|
|
||||||
// PDUs
|
// PDUs
|
||||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||||
|
@ -105,19 +92,13 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
let mut roomuser_prefix = roomid_prefix.clone();
|
let mut roomuser_prefix = roomid_prefix.clone();
|
||||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||||
|
|
||||||
futures.push(
|
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix));
|
||||||
self.roomusertype_roomuserdataid
|
|
||||||
.watch_prefix(&roomuser_prefix),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut globaluserdata_prefix = vec![0xff];
|
let mut globaluserdata_prefix = vec![0xFF];
|
||||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||||
|
|
||||||
futures.push(
|
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix));
|
||||||
self.roomusertype_roomuserdataid
|
|
||||||
.watch_prefix(&globaluserdata_prefix),
|
|
||||||
);
|
|
||||||
|
|
||||||
// More key changes (used when user is not joined to any rooms)
|
// More key changes (used when user is not joined to any rooms)
|
||||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||||
|
@ -133,9 +114,7 @@ impl service::globals::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn cleanup(&self) -> Result<()> {
|
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||||
self.db.cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn memory_usage(&self) -> String {
|
fn memory_usage(&self) -> String {
|
||||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||||
|
@ -210,13 +189,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||||
Ok,
|
Ok,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
|
||||||
|
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(
|
||||||
// 1. version
|
// 1. version
|
||||||
parts
|
parts.next().expect("splitn always returns at least one element"),
|
||||||
.next()
|
|
||||||
.expect("splitn always returns at least one element"),
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||||
.and_then(|version| {
|
.and_then(|version| {
|
||||||
|
@ -231,21 +208,16 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
fn remove_keypair(&self) -> Result<()> {
|
|
||||||
self.global.remove(b"keypair")
|
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
|
||||||
}
|
|
||||||
|
|
||||||
fn add_signing_key(
|
fn add_signing_key(
|
||||||
&self,
|
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||||
origin: &ServerName,
|
|
||||||
new_keys: ServerSigningKeys,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||||
// Not atomic, but this is not critical
|
// Not atomic, but this is not critical
|
||||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||||
|
|
||||||
let mut keys = signingkeys
|
let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| {
|
||||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
|
||||||
.unwrap_or_else(|| {
|
|
||||||
// Just insert "now", it doesn't matter
|
// Just insert "now", it doesn't matter
|
||||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||||
});
|
});
|
||||||
|
@ -265,31 +237,21 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let mut tree = keys.verify_keys;
|
let mut tree = keys.verify_keys;
|
||||||
tree.extend(
|
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||||
keys.old_verify_keys
|
|
||||||
.into_iter()
|
|
||||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(tree)
|
Ok(tree)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||||
fn signing_keys_for(
|
/// for the server.
|
||||||
&self,
|
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||||
origin: &ServerName,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
|
||||||
let signingkeys = self
|
let signingkeys = self
|
||||||
.server_signingkeys
|
.server_signingkeys
|
||||||
.get(origin.as_bytes())?
|
.get(origin.as_bytes())?
|
||||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||||
.map(|keys: ServerSigningKeys| {
|
.map(|keys: ServerSigningKeys| {
|
||||||
let mut tree = keys.verify_keys;
|
let mut tree = keys.verify_keys;
|
||||||
tree.extend(
|
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||||
keys.old_verify_keys
|
|
||||||
.into_iter()
|
|
||||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
|
||||||
);
|
|
||||||
tree
|
tree
|
||||||
})
|
})
|
||||||
.unwrap_or_else(BTreeMap::new);
|
.unwrap_or_else(BTreeMap::new);
|
||||||
|
@ -299,8 +261,7 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
||||||
|
|
||||||
fn database_version(&self) -> Result<u64> {
|
fn database_version(&self) -> Result<u64> {
|
||||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||||
utils::u64_from_bytes(&version)
|
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,35 +12,30 @@ use ruma::{
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||||
|
|
||||||
impl service::key_backups::Data for KeyValueDatabase {
|
impl service::key_backups::Data for KeyValueDatabase {
|
||||||
fn create_backup(
|
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let version = services().globals.next_count()?.to_string();
|
let version = services().globals.next_count()?.to_string();
|
||||||
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
self.backupid_algorithm.insert(
|
self.backupid_algorithm.insert(
|
||||||
&key,
|
&key,
|
||||||
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
|
||||||
)?;
|
)?;
|
||||||
self.backupid_etag
|
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
|
||||||
Ok(version)
|
Ok(version)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
self.backupid_algorithm.remove(&key)?;
|
self.backupid_algorithm.remove(&key)?;
|
||||||
self.backupid_etag.remove(&key)?;
|
self.backupid_etag.remove(&key)?;
|
||||||
|
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
|
|
||||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||||
|
@ -49,33 +44,23 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_backup(
|
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Tried to update nonexistent backup.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.backupid_algorithm
|
self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
self.backupid_etag
|
|
||||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
|
||||||
Ok(version.to_owned())
|
Ok(version.to_owned())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
let mut last_possible_key = prefix.clone();
|
let mut last_possible_key = prefix.clone();
|
||||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||||
|
|
||||||
|
@ -84,22 +69,15 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||||
.next()
|
.next()
|
||||||
.map(|(key, _)| {
|
.map(|(key, _)| {
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_latest_backup(
|
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
let mut last_possible_key = prefix.clone();
|
let mut last_possible_key = prefix.clone();
|
||||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||||
|
|
||||||
|
@ -109,17 +87,14 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
.next()
|
.next()
|
||||||
.map(|(key, value)| {
|
.map(|(key, value)| {
|
||||||
let version = utils::string_from_bytes(
|
let version = utils::string_from_bytes(
|
||||||
key.rsplit(|&b| b == 0xff)
|
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"),
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
version,
|
version,
|
||||||
serde_json::from_slice(&value).map_err(|_| {
|
serde_json::from_slice(&value)
|
||||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
|
||||||
})?,
|
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
|
@ -127,53 +102,41 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
self.backupid_algorithm
|
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
|
||||||
.get(&key)?
|
|
||||||
.map_or(Ok(None), |bytes| {
|
|
||||||
serde_json::from_slice(&bytes)
|
serde_json::from_slice(&bytes)
|
||||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_key(
|
fn add_key(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
key_data: &Raw<KeyBackupData>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Tried to update nonexistent backup.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.backupid_etag
|
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
|
||||||
|
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(session_id.as_bytes());
|
key.extend_from_slice(session_id.as_bytes());
|
||||||
|
|
||||||
self.backupkeyid_backup
|
self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?;
|
||||||
.insert(&key, key_data.json().get().as_bytes())?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(version.as_bytes());
|
prefix.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||||
|
@ -181,62 +144,45 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
|
|
||||||
Ok(utils::u64_from_bytes(
|
Ok(utils::u64_from_bytes(
|
||||||
&self
|
&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||||
.backupid_etag
|
|
||||||
.get(&key)?
|
|
||||||
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||||
.to_string())
|
.to_string())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_all(
|
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(version.as_bytes());
|
prefix.extend_from_slice(version.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||||
|
|
||||||
for result in self
|
for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
|
||||||
.backupkeyid_backup
|
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(key, value)| {
|
|
||||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
|
||||||
|
|
||||||
let session_id =
|
let session_id = utils::string_from_bytes(
|
||||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let room_id = RoomId::parse(
|
let room_id = RoomId::parse(
|
||||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
utils::string_from_bytes(
|
||||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||||
})?)
|
)
|
||||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||||
)
|
)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
|
||||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
let key_data = serde_json::from_slice(&value)
|
||||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok::<_, Error>((room_id, session_id, key_data))
|
Ok::<_, Error>((room_id, session_id, key_data))
|
||||||
})
|
}) {
|
||||||
{
|
|
||||||
let (room_id, session_id, key_data) = result?;
|
let (room_id, session_id, key_data) = result?;
|
||||||
rooms
|
rooms
|
||||||
.entry(room_id)
|
.entry(room_id)
|
||||||
|
@ -251,35 +197,28 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_room(
|
fn get_room(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(version.as_bytes());
|
prefix.extend_from_slice(version.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(room_id.as_bytes());
|
prefix.extend_from_slice(room_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.backupkeyid_backup
|
.backupkeyid_backup
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
.map(|(key, value)| {
|
.map(|(key, value)| {
|
||||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||||
|
|
||||||
let session_id =
|
let session_id = utils::string_from_bytes(
|
||||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
let key_data = serde_json::from_slice(&value)
|
||||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok::<_, Error>((session_id, key_data))
|
Ok::<_, Error>((session_id, key_data))
|
||||||
})
|
})
|
||||||
|
@ -288,35 +227,30 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_session(
|
fn get_session(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(session_id.as_bytes());
|
key.extend_from_slice(session_id.as_bytes());
|
||||||
|
|
||||||
self.backupkeyid_backup
|
self.backupkeyid_backup
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.map(|value| {
|
.map(|value| {
|
||||||
serde_json::from_slice(&value).map_err(|_| {
|
serde_json::from_slice(&value)
|
||||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
|
|
||||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||||
|
@ -327,11 +261,11 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
|
|
||||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||||
|
@ -340,19 +274,13 @@ impl service::key_backups::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete_room_key(
|
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(version.as_bytes());
|
key.extend_from_slice(version.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(session_id.as_bytes());
|
key.extend_from_slice(session_id.as_bytes());
|
||||||
|
|
||||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||||
|
|
|
@ -9,31 +9,16 @@ use crate::{
|
||||||
|
|
||||||
impl service::media::Data for KeyValueDatabase {
|
impl service::media::Data for KeyValueDatabase {
|
||||||
fn create_file_metadata(
|
fn create_file_metadata(
|
||||||
&self,
|
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||||
mxc: String,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
content_disposition: Option<&str>,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
let mut key = mxc.as_bytes().to_vec();
|
let mut key = mxc.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(&width.to_be_bytes());
|
key.extend_from_slice(&width.to_be_bytes());
|
||||||
key.extend_from_slice(&height.to_be_bytes());
|
key.extend_from_slice(&height.to_be_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(
|
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||||
content_disposition
|
key.push(0xFF);
|
||||||
.as_ref()
|
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||||
.map(|f| f.as_bytes())
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
key.push(0xff);
|
|
||||||
key.extend_from_slice(
|
|
||||||
content_type
|
|
||||||
.as_ref()
|
|
||||||
.map(|c| c.as_bytes())
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
|
|
||||||
self.mediaid_file.insert(&key, &[])?;
|
self.mediaid_file.insert(&key, &[])?;
|
||||||
|
|
||||||
|
@ -44,7 +29,7 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
debug!("MXC URI: {:?}", mxc);
|
debug!("MXC URI: {:?}", mxc);
|
||||||
|
|
||||||
let mut prefix = mxc.as_bytes().to_vec();
|
let mut prefix = mxc.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
debug!("MXC db prefix: {:?}", prefix);
|
debug!("MXC db prefix: {:?}", prefix);
|
||||||
|
|
||||||
|
@ -61,7 +46,7 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
debug!("MXC URI: {:?}", mxc);
|
debug!("MXC URI: {:?}", mxc);
|
||||||
|
|
||||||
let mut prefix = mxc.as_bytes().to_vec();
|
let mut prefix = mxc.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut keys: Vec<Vec<u8>> = vec![];
|
let mut keys: Vec<Vec<u8>> = vec![];
|
||||||
|
|
||||||
|
@ -81,16 +66,13 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search_file_metadata(
|
fn search_file_metadata(
|
||||||
&self,
|
&self, mxc: String, width: u32, height: u32,
|
||||||
mxc: String,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||||
let mut prefix = mxc.as_bytes().to_vec();
|
let mut prefix = mxc.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(&width.to_be_bytes());
|
prefix.extend_from_slice(&width.to_be_bytes());
|
||||||
prefix.extend_from_slice(&height.to_be_bytes());
|
prefix.extend_from_slice(&height.to_be_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let (key, _) = self
|
let (key, _) = self
|
||||||
.mediaid_file
|
.mediaid_file
|
||||||
|
@ -98,34 +80,32 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
.next()
|
.next()
|
||||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||||
|
|
||||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||||
|
|
||||||
let content_type = parts
|
let content_type = parts
|
||||||
.next()
|
.next()
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::string_from_bytes(bytes).map_err(|_| {
|
utils::string_from_bytes(bytes)
|
||||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
|
||||||
let content_disposition_bytes = parts
|
let content_disposition_bytes =
|
||||||
.next()
|
parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
|
||||||
|
|
||||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(
|
Some(
|
||||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
utils::string_from_bytes(content_disposition_bytes)
|
||||||
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
|
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
};
|
};
|
||||||
Ok((content_disposition, content_type, key))
|
Ok((content_disposition, content_type, key))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Gets all the media keys in our database (this includes all the metadata associated with it such as width, height, content-type, etc)
|
/// Gets all the media keys in our database (this includes all the metadata
|
||||||
|
/// associated with it such as width, height, content-type, etc)
|
||||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||||
let mut keys: Vec<Vec<u8>> = vec![];
|
let mut keys: Vec<Vec<u8>> = vec![];
|
||||||
|
|
||||||
|
@ -136,44 +116,22 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
Ok(keys)
|
Ok(keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_url_preview(&self, url: &str) -> Result<()> {
|
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
|
||||||
self.url_previews.remove(url.as_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_url_preview(
|
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
|
||||||
&self,
|
|
||||||
url: &str,
|
|
||||||
data: &UrlPreviewData,
|
|
||||||
timestamp: std::time::Duration,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut value = Vec::<u8>::new();
|
let mut value = Vec::<u8>::new();
|
||||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||||
value.push(0xff);
|
value.push(0xFF);
|
||||||
value.extend_from_slice(
|
value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||||
data.title
|
value.push(0xFF);
|
||||||
.as_ref()
|
value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||||
.map(std::string::String::as_bytes)
|
value.push(0xFF);
|
||||||
.unwrap_or_default(),
|
value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||||
);
|
value.push(0xFF);
|
||||||
value.push(0xff);
|
|
||||||
value.extend_from_slice(
|
|
||||||
data.description
|
|
||||||
.as_ref()
|
|
||||||
.map(std::string::String::as_bytes)
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
value.push(0xff);
|
|
||||||
value.extend_from_slice(
|
|
||||||
data.image
|
|
||||||
.as_ref()
|
|
||||||
.map(std::string::String::as_bytes)
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
value.push(0xff);
|
|
||||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||||
value.push(0xff);
|
value.push(0xFF);
|
||||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||||
value.push(0xff);
|
value.push(0xFF);
|
||||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||||
|
|
||||||
self.url_previews.insert(url.as_bytes(), &value)
|
self.url_previews.insert(url.as_bytes(), &value)
|
||||||
|
@ -182,54 +140,33 @@ impl service::media::Data for KeyValueDatabase {
|
||||||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||||
|
|
||||||
let mut values = values.split(|&b| b == 0xff);
|
let mut values = values.split(|&b| b == 0xFF);
|
||||||
|
|
||||||
let _ts = match values
|
let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||||
.next()
|
|
||||||
.map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array")))
|
|
||||||
{
|
|
||||||
Some(0) => None,
|
Some(0) => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let title = match values
|
let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||||
.next()
|
|
||||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
|
||||||
{
|
|
||||||
Some(s) if s.is_empty() => None,
|
Some(s) if s.is_empty() => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let description = match values
|
let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||||
.next()
|
|
||||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
|
||||||
{
|
|
||||||
Some(s) if s.is_empty() => None,
|
Some(s) if s.is_empty() => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let image = match values
|
let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||||
.next()
|
|
||||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
|
||||||
{
|
|
||||||
Some(s) if s.is_empty() => None,
|
Some(s) if s.is_empty() => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let image_size = match values
|
let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||||
.next()
|
|
||||||
.map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array")))
|
|
||||||
{
|
|
||||||
Some(0) => None,
|
Some(0) => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let image_width = match values
|
let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||||
.next()
|
|
||||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
|
||||||
{
|
|
||||||
Some(0) => None,
|
Some(0) => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
let image_height = match values
|
let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||||
.next()
|
|
||||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
|
||||||
{
|
|
||||||
Some(0) => None,
|
Some(0) => None,
|
||||||
x => x,
|
x => x,
|
||||||
};
|
};
|
||||||
|
|
|
@ -10,66 +10,50 @@ impl service::pusher::Data for KeyValueDatabase {
|
||||||
match &pusher {
|
match &pusher {
|
||||||
set_pusher::v3::PusherAction::Post(data) => {
|
set_pusher::v3::PusherAction::Post(data) => {
|
||||||
let mut key = sender.as_bytes().to_vec();
|
let mut key = sender.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
|
||||||
self.senderkey_pusher.insert(
|
self.senderkey_pusher
|
||||||
&key,
|
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
|
||||||
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
|
|
||||||
)?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
},
|
||||||
set_pusher::v3::PusherAction::Delete(ids) => {
|
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||||
let mut key = sender.as_bytes().to_vec();
|
let mut key = sender.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||||
self.senderkey_pusher
|
self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into)
|
||||||
.remove(&key)
|
},
|
||||||
.map(|_| ())
|
|
||||||
.map_err(Into::into)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||||
let mut senderkey = sender.as_bytes().to_vec();
|
let mut senderkey = sender.as_bytes().to_vec();
|
||||||
senderkey.push(0xff);
|
senderkey.push(0xFF);
|
||||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||||
|
|
||||||
self.senderkey_pusher
|
self.senderkey_pusher
|
||||||
.get(&senderkey)?
|
.get(&senderkey)?
|
||||||
.map(|push| {
|
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||||
serde_json::from_slice(&push)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||||
let mut prefix = sender.as_bytes().to_vec();
|
let mut prefix = sender.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
self.senderkey_pusher
|
self.senderkey_pusher
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
.map(|(_, push)| {
|
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||||
serde_json::from_slice(&push)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
|
||||||
})
|
|
||||||
.collect()
|
.collect()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_pushkeys<'a>(
|
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||||
&'a self,
|
|
||||||
sender: &UserId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
|
||||||
let mut prefix = sender.as_bytes().to_vec();
|
let mut prefix = sender.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||||
let mut parts = k.splitn(2, |&b| b == 0xff);
|
let mut parts = k.splitn(2, |&b| b == 0xFF);
|
||||||
let _senderkey = parts.next();
|
let _senderkey = parts.next();
|
||||||
let push_key = parts
|
let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||||
.next()
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
|
||||||
let push_key_string = utils::string_from_bytes(push_key)
|
let push_key_string = utils::string_from_bytes(push_key)
|
||||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||||
|
|
||||||
|
|
|
@ -4,10 +4,9 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
||||||
|
|
||||||
impl service::rooms::alias::Data for KeyValueDatabase {
|
impl service::rooms::alias::Data for KeyValueDatabase {
|
||||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||||
self.alias_roomid
|
self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||||
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
|
||||||
let mut aliasid = room_id.as_bytes().to_vec();
|
let mut aliasid = room_id.as_bytes().to_vec();
|
||||||
aliasid.push(0xff);
|
aliasid.push(0xFF);
|
||||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -16,17 +15,14 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
||||||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||||
let mut prefix = room_id;
|
let mut prefix = room_id;
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||||
self.aliasid_alias.remove(&key)?;
|
self.aliasid_alias.remove(&key)?;
|
||||||
}
|
}
|
||||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||||
ErrorKind::NotFound,
|
|
||||||
"Alias does not exist.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -35,20 +31,20 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
||||||
self.alias_roomid
|
self.alias_roomid
|
||||||
.get(alias.alias().as_bytes())?
|
.get(alias.alias().as_bytes())?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
RoomId::parse(
|
||||||
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
|
utils::string_from_bytes(&bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn local_aliases_for_room<'a>(
|
fn local_aliases_for_room<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId,
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||||
utils::string_from_bytes(&bytes)
|
utils::string_from_bytes(&bytes)
|
||||||
|
@ -58,27 +54,17 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn all_local_aliases<'a>(
|
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||||
&'a self,
|
Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| {
|
||||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
|
||||||
Box::new(
|
|
||||||
self.alias_roomid
|
|
||||||
.iter()
|
|
||||||
.map(|(room_alias_bytes, room_id_bytes)| {
|
|
||||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
|
||||||
Error::bad_database("Invalid alias bytes in aliasid_alias.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
|
||||||
Error::bad_database("Invalid room_id bytes in aliasid_alias.")
|
|
||||||
})?
|
|
||||||
.try_into()
|
.try_into()
|
||||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||||
|
|
||||||
Ok((room_id, room_alias_localpart))
|
Ok((room_id, room_alias_localpart))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -12,10 +12,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||||
// We only save auth chains for single events in the db
|
// We only save auth chains for single events in the db
|
||||||
if key.len() == 1 {
|
if key.len() == 1 {
|
||||||
// Check DB cache
|
// Check DB cache
|
||||||
let chain = self
|
let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| {
|
||||||
.shorteventid_authchain
|
|
||||||
.get(&key[0].to_be_bytes())?
|
|
||||||
.map(|chain| {
|
|
||||||
chain
|
chain
|
||||||
.chunks_exact(size_of::<u64>())
|
.chunks_exact(size_of::<u64>())
|
||||||
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))
|
||||||
|
@ -26,10 +23,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||||
let chain = Arc::new(chain);
|
let chain = Arc::new(chain);
|
||||||
|
|
||||||
// Cache in RAM
|
// Cache in RAM
|
||||||
self.auth_chain_cache
|
self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(vec![key[0]], Arc::clone(&chain));
|
|
||||||
|
|
||||||
return Ok(Some(chain));
|
return Ok(Some(chain));
|
||||||
}
|
}
|
||||||
|
@ -43,18 +37,12 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
||||||
if key.len() == 1 {
|
if key.len() == 1 {
|
||||||
self.shorteventid_authchain.insert(
|
self.shorteventid_authchain.insert(
|
||||||
&key[0].to_be_bytes(),
|
&key[0].to_be_bytes(),
|
||||||
&auth_chain
|
&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(),
|
||||||
.iter()
|
|
||||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
|
||||||
.collect::<Vec<u8>>(),
|
|
||||||
)?;
|
)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cache in RAM
|
// Cache in RAM
|
||||||
self.auth_chain_cache
|
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(key, auth_chain);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,13 +3,9 @@ use ruma::{OwnedRoomId, RoomId};
|
||||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||||
|
|
||||||
impl service::rooms::directory::Data for KeyValueDatabase {
|
impl service::rooms::directory::Data for KeyValueDatabase {
|
||||||
fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
|
||||||
self.publicroomids.insert(room_id.as_bytes(), &[])
|
|
||||||
}
|
|
||||||
|
|
||||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
|
||||||
self.publicroomids.remove(room_id.as_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||||
|
@ -18,9 +14,8 @@ impl service::rooms::directory::Data for KeyValueDatabase {
|
||||||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||||
RoomId::parse(
|
RoomId::parse(
|
||||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
utils::string_from_bytes(&bytes)
|
||||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||||
}))
|
}))
|
||||||
|
|
|
@ -1,8 +1,6 @@
|
||||||
use std::time::Duration;
|
use std::time::Duration;
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
|
||||||
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
|
|
||||||
};
|
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::{
|
use crate::{
|
||||||
|
@ -63,18 +61,11 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||||
presence.last_count = count;
|
presence.last_count = count;
|
||||||
|
|
||||||
presence
|
presence
|
||||||
}
|
},
|
||||||
None => Presence::new(
|
None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
|
||||||
new_state.clone(),
|
|
||||||
new_state == PresenceState::Online,
|
|
||||||
now,
|
|
||||||
count,
|
|
||||||
None,
|
|
||||||
),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
self.roomuserid_presence
|
self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||||
.insert(&key, &new_presence.to_json_bytes()?)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let timeout = match new_state {
|
let timeout = match new_state {
|
||||||
|
@ -82,22 +73,15 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||||
_ => services().globals.config.presence_offline_timeout_s,
|
_ => services().globals.config.presence_offline_timeout_s,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.presence_timer_sender
|
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
|
||||||
.map_err(|e| {
|
|
||||||
error!("Failed to add presence timer: {}", e);
|
error!("Failed to add presence timer: {}", e);
|
||||||
Error::bad_database("Failed to add presence timer")
|
Error::bad_database("Failed to add presence timer")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_presence(
|
fn set_presence(
|
||||||
&self,
|
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||||
room_id: &RoomId,
|
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||||
user_id: &UserId,
|
|
||||||
presence_state: PresenceState,
|
|
||||||
currently_active: Option<bool>,
|
|
||||||
last_active_ago: Option<UInt>,
|
|
||||||
status_msg: Option<String>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let now = utils::millis_since_unix_epoch();
|
let now = utils::millis_since_unix_epoch();
|
||||||
let last_active_ts = match last_active_ago {
|
let last_active_ts = match last_active_ago {
|
||||||
|
@ -120,15 +104,12 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||||
_ => services().globals.config.presence_offline_timeout_s,
|
_ => services().globals.config.presence_offline_timeout_s,
|
||||||
};
|
};
|
||||||
|
|
||||||
self.presence_timer_sender
|
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
|
||||||
.map_err(|e| {
|
|
||||||
error!("Failed to add presence timer: {}", e);
|
error!("Failed to add presence timer: {}", e);
|
||||||
Error::bad_database("Failed to add presence timer")
|
Error::bad_database("Failed to add presence timer")
|
||||||
})?;
|
})?;
|
||||||
|
|
||||||
self.roomuserid_presence
|
self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?;
|
||||||
.insert(&key, &presence.to_json_bytes()?)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -144,29 +125,25 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn presence_since<'a>(
|
fn presence_since<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId, since: u64,
|
||||||
room_id: &RoomId,
|
|
||||||
since: u64,
|
|
||||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||||
let prefix = [room_id.as_bytes(), &[0xff]].concat();
|
let prefix = [room_id.as_bytes(), &[0xFF]].concat();
|
||||||
|
|
||||||
Box::new(
|
Box::new(
|
||||||
self.roomuserid_presence
|
self.roomuserid_presence
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
.flat_map(
|
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||||
|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
|
||||||
let user_id = user_id_from_bytes(
|
let user_id = user_id_from_bytes(
|
||||||
key.rsplit(|byte| *byte == 0xff).next().ok_or_else(|| {
|
key.rsplit(|byte| *byte == 0xFF)
|
||||||
Error::bad_database("No UserID bytes in presence key")
|
.next()
|
||||||
})?,
|
.ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?,
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||||
let presence_event = presence.to_presence_event(&user_id)?;
|
let presence_event = presence.to_presence_event(&user_id)?;
|
||||||
|
|
||||||
Ok((user_id, presence.last_count, presence_event))
|
Ok((user_id, presence.last_count, presence_event))
|
||||||
},
|
})
|
||||||
)
|
|
||||||
.filter(move |(_, count, _)| *count > since),
|
.filter(move |(_, count, _)| *count > since),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -174,5 +151,5 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
||||||
|
|
||||||
#[inline]
|
#[inline]
|
||||||
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
|
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
|
||||||
[room_id.as_bytes(), &[0xff], user_id.as_bytes()].concat()
|
[room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat()
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,20 +1,13 @@
|
||||||
use std::mem;
|
use std::mem;
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};
|
||||||
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
|
|
||||||
};
|
|
||||||
|
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||||
|
|
||||||
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||||
fn readreceipt_update(
|
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event: ReceiptEvent,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut last_possible_key = prefix.clone();
|
let mut last_possible_key = prefix.clone();
|
||||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||||
|
@ -25,19 +18,15 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||||
.iter_from(&last_possible_key, true)
|
.iter_from(&last_possible_key, true)
|
||||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||||
.find(|(key, _)| {
|
.find(|(key, _)| {
|
||||||
key.rsplit(|&b| b == 0xff)
|
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes()
|
||||||
.next()
|
}) {
|
||||||
.expect("rsplit always returns an element")
|
|
||||||
== user_id.as_bytes()
|
|
||||||
})
|
|
||||||
{
|
|
||||||
// This is the old room_latest
|
// This is the old room_latest
|
||||||
self.readreceiptid_readreceipt.remove(&old)?;
|
self.readreceiptid_readreceipt.remove(&old)?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut room_latest_id = prefix;
|
let mut room_latest_id = prefix;
|
||||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||||
room_latest_id.push(0xff);
|
room_latest_id.push(0xFF);
|
||||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.readreceiptid_readreceipt.insert(
|
self.readreceiptid_readreceipt.insert(
|
||||||
|
@ -49,20 +38,10 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn readreceipts_since<'a>(
|
fn readreceipts_since<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId, since: u64,
|
||||||
room_id: &RoomId,
|
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
|
||||||
since: u64,
|
|
||||||
) -> Box<
|
|
||||||
dyn Iterator<
|
|
||||||
Item = Result<(
|
|
||||||
OwnedUserId,
|
|
||||||
u64,
|
|
||||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
|
||||||
)>,
|
|
||||||
> + 'a,
|
|
||||||
> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
let prefix2 = prefix.clone();
|
let prefix2 = prefix.clone();
|
||||||
|
|
||||||
let mut first_possible_edu = prefix.clone();
|
let mut first_possible_edu = prefix.clone();
|
||||||
|
@ -73,33 +52,22 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||||
.iter_from(&first_possible_edu, false)
|
.iter_from(&first_possible_edu, false)
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||||
.map(move |(k, v)| {
|
.map(move |(k, v)| {
|
||||||
let count = utils::u64_from_bytes(
|
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||||
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
|
|
||||||
)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||||
let user_id = UserId::parse(
|
let user_id = UserId::parse(
|
||||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
|
||||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||||
|
|
||||||
let mut json =
|
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
|
||||||
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
|
||||||
Error::bad_database(
|
|
||||||
"Read receipt in roomlatestid_roomlatest is invalid json.",
|
|
||||||
)
|
|
||||||
})?;
|
|
||||||
json.remove("room_id");
|
json.remove("room_id");
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
user_id,
|
user_id,
|
||||||
count,
|
count,
|
||||||
Raw::from_json(
|
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
|
||||||
serde_json::value::to_raw_value(&json)
|
|
||||||
.expect("json is valid raw value"),
|
|
||||||
),
|
|
||||||
))
|
))
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
|
@ -107,42 +75,37 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.roomuserid_privateread
|
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
|
||||||
.insert(&key, &count.to_be_bytes())?;
|
|
||||||
|
|
||||||
self.roomuserid_lastprivatereadupdate
|
self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.roomuserid_privateread
|
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
|
||||||
.get(&key)?
|
Ok(Some(
|
||||||
.map_or(Ok(None), |v| {
|
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
|
||||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
))
|
||||||
Error::bad_database("Invalid private read marker bytes")
|
|
||||||
})?))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.roomuserid_lastprivatereadupdate
|
.roomuserid_lastprivatereadupdate
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.unwrap_or(0))
|
.unwrap_or(0))
|
||||||
|
|
|
@ -7,47 +7,38 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
||||||
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let count = services().globals.next_count()?.to_be_bytes();
|
let count = services().globals.next_count()?.to_be_bytes();
|
||||||
|
|
||||||
let mut room_typing_id = prefix;
|
let mut room_typing_id = prefix;
|
||||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||||
room_typing_id.push(0xff);
|
room_typing_id.push(0xFF);
|
||||||
room_typing_id.extend_from_slice(&count);
|
room_typing_id.extend_from_slice(&count);
|
||||||
|
|
||||||
self.typingid_userid
|
self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?;
|
||||||
.insert(&room_typing_id, user_id.as_bytes())?;
|
|
||||||
|
|
||||||
self.roomid_lasttypingupdate
|
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?;
|
||||||
.insert(room_id.as_bytes(), &count)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let user_id = user_id.to_string();
|
let user_id = user_id.to_string();
|
||||||
|
|
||||||
let mut found_outdated = false;
|
let mut found_outdated = false;
|
||||||
|
|
||||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||||
for outdated_edu in self
|
for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) {
|
||||||
.typingid_userid
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
|
||||||
{
|
|
||||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||||
found_outdated = true;
|
found_outdated = true;
|
||||||
}
|
}
|
||||||
|
|
||||||
if found_outdated {
|
if found_outdated {
|
||||||
self.roomid_lasttypingupdate.insert(
|
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||||
room_id.as_bytes(),
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -55,7 +46,7 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let current_timestamp = utils::millis_since_unix_epoch();
|
let current_timestamp = utils::millis_since_unix_epoch();
|
||||||
|
|
||||||
|
@ -69,9 +60,9 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
Ok::<_, Error>((
|
Ok::<_, Error>((
|
||||||
key.clone(),
|
key.clone(),
|
||||||
utils::u64_from_bytes(
|
utils::u64_from_bytes(
|
||||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
&key.splitn(2, |&b| b == 0xFF)
|
||||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
.nth(1)
|
||||||
})?[0..mem::size_of::<u64>()],
|
.ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::<u64>()],
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||||
))
|
))
|
||||||
|
@ -85,10 +76,7 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
if found_outdated {
|
if found_outdated {
|
||||||
self.roomid_lasttypingupdate.insert(
|
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||||
room_id.as_bytes(),
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
|
@ -99,9 +87,8 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
.roomid_lasttypingupdate
|
.roomid_lasttypingupdate
|
||||||
.get(room_id.as_bytes())?
|
.get(room_id.as_bytes())?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.unwrap_or(0))
|
.unwrap_or(0))
|
||||||
|
@ -109,14 +96,15 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut user_ids = HashSet::new();
|
let mut user_ids = HashSet::new();
|
||||||
|
|
||||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
let user_id = UserId::parse(
|
||||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
utils::string_from_bytes(&user_id)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||||
|
|
||||||
user_ids.insert(user_id);
|
user_ids.insert(user_id);
|
||||||
|
|
|
@ -4,35 +4,28 @@ use crate::{database::KeyValueDatabase, service, Result};
|
||||||
|
|
||||||
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||||
fn lazy_load_was_sent_before(
|
fn lazy_load_was_sent_before(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
ll_user: &UserId,
|
|
||||||
) -> Result<bool> {
|
) -> Result<bool> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(device_id.as_bytes());
|
key.extend_from_slice(device_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(ll_user.as_bytes());
|
key.extend_from_slice(ll_user.as_bytes());
|
||||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lazy_load_confirm_delivery(
|
fn lazy_load_confirm_delivery(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(device_id.as_bytes());
|
prefix.extend_from_slice(device_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(room_id.as_bytes());
|
prefix.extend_from_slice(room_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for ll_id in confirmed_user_ids {
|
for ll_id in confirmed_user_ids {
|
||||||
let mut key = prefix.clone();
|
let mut key = prefix.clone();
|
||||||
|
@ -43,18 +36,13 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn lazy_load_reset(
|
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(device_id.as_bytes());
|
prefix.extend_from_slice(device_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(room_id.as_bytes());
|
prefix.extend_from_slice(room_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||||
self.lazyloadedids.remove(&key)?;
|
self.lazyloadedids.remove(&key)?;
|
||||||
|
|
|
@ -11,20 +11,14 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
|
||||||
};
|
};
|
||||||
|
|
||||||
// Look for PDUs in that room.
|
// Look for PDUs in that room.
|
||||||
Ok(self
|
Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some())
|
||||||
.pduid_pdu
|
|
||||||
.iter_from(&prefix, false)
|
|
||||||
.next()
|
|
||||||
.filter(|(k, _)| k.starts_with(&prefix))
|
|
||||||
.is_some())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||||
RoomId::parse(
|
RoomId::parse(
|
||||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
utils::string_from_bytes(&bytes)
|
||||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||||
}))
|
}))
|
||||||
|
@ -44,9 +38,7 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
|
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
|
||||||
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||||
if banned {
|
if banned {
|
||||||
|
|
|
@ -4,17 +4,13 @@ use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
|
||||||
|
|
||||||
impl service::rooms::outlier::Data for KeyValueDatabase {
|
impl service::rooms::outlier::Data for KeyValueDatabase {
|
||||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||||
self.eventid_outlierpdu
|
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||||
.get(event_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |pdu| {
|
|
||||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||||
self.eventid_outlierpdu
|
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||||
.get(event_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |pdu| {
|
|
||||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -17,11 +17,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn relations_until<'a>(
|
fn relations_until<'a>(
|
||||||
&'a self,
|
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
|
||||||
user_id: &'a UserId,
|
|
||||||
shortroomid: u64,
|
|
||||||
target: u64,
|
|
||||||
until: PduCount,
|
|
||||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||||
let prefix = target.to_be_bytes().to_vec();
|
let prefix = target.to_be_bytes().to_vec();
|
||||||
let mut current = prefix.clone();
|
let mut current = prefix.clone();
|
||||||
|
@ -31,15 +27,13 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||||
PduCount::Backfilled(x) => {
|
PduCount::Backfilled(x) => {
|
||||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||||
u64::MAX - x - 1
|
u64::MAX - x - 1
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||||
|
|
||||||
Ok(Box::new(
|
Ok(Box::new(
|
||||||
self.tofrom_relation
|
self.tofrom_relation.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||||
.iter_from(¤t, true)
|
move |(tofrom, _data)| {
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
|
||||||
.map(move |(tofrom, _data)| {
|
|
||||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||||
|
|
||||||
|
@ -55,7 +49,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||||
pdu.remove_transaction_id()?;
|
pdu.remove_transaction_id()?;
|
||||||
}
|
}
|
||||||
Ok((PduCount::Normal(from), pdu))
|
Ok((PduCount::Normal(from), pdu))
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,8 +75,6 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||||
self.softfailedeventids
|
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
|
||||||
.get(event_id.as_bytes())
|
|
||||||
.map(|o| o.is_some())
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -14,7 +14,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
|
||||||
.map(|word| {
|
.map(|word| {
|
||||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||||
key.extend_from_slice(word.as_bytes());
|
key.extend_from_slice(word.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
|
||||||
(key, Vec::new())
|
(key, Vec::new())
|
||||||
});
|
});
|
||||||
|
@ -23,13 +23,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
|
||||||
let prefix = services()
|
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_shortroomid(room_id)?
|
|
||||||
.expect("room exists")
|
|
||||||
.to_be_bytes()
|
|
||||||
.to_vec();
|
|
||||||
|
|
||||||
let words: Vec<_> = search_string
|
let words: Vec<_> = search_string
|
||||||
.split_terminator(|c: char| !c.is_alphanumeric())
|
.split_terminator(|c: char| !c.is_alphanumeric())
|
||||||
|
@ -40,7 +34,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
|
||||||
let iterators = words.clone().into_iter().map(move |word| {
|
let iterators = words.clone().into_iter().map(move |word| {
|
||||||
let mut prefix2 = prefix.clone();
|
let mut prefix2 = prefix.clone();
|
||||||
prefix2.extend_from_slice(word.as_bytes());
|
prefix2.extend_from_slice(word.as_bytes());
|
||||||
prefix2.push(0xff);
|
prefix2.push(0xFF);
|
||||||
let prefix3 = prefix2.clone();
|
let prefix3 = prefix2.clone();
|
||||||
|
|
||||||
let mut last_possible_id = prefix2.clone();
|
let mut last_possible_id = prefix2.clone();
|
||||||
|
|
|
@ -12,79 +12,57 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||||
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
|
Some(shorteventid) => {
|
||||||
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
|
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
|
||||||
|
},
|
||||||
None => {
|
None => {
|
||||||
let shorteventid = services().globals.next_count()?;
|
let shorteventid = services().globals.next_count()?;
|
||||||
self.eventid_shorteventid
|
self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||||
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||||
self.shorteventid_eventid
|
|
||||||
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
|
||||||
shorteventid
|
shorteventid
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
self.eventidshort_cache
|
self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short);
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(event_id.to_owned(), short);
|
|
||||||
|
|
||||||
Ok(short)
|
Ok(short)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_shortstatekey(
|
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
|
||||||
&self,
|
if let Some(short) =
|
||||||
event_type: &StateEventType,
|
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||||
state_key: &str,
|
|
||||||
) -> Result<Option<u64>> {
|
|
||||||
if let Some(short) = self
|
|
||||||
.statekeyshort_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
|
||||||
{
|
{
|
||||||
return Ok(Some(*short));
|
return Ok(Some(*short));
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||||
statekey_vec.push(0xff);
|
statekey_vec.push(0xFF);
|
||||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||||
|
|
||||||
let short = self
|
let short = self
|
||||||
.statekey_shortstatekey
|
.statekey_shortstatekey
|
||||||
.get(&statekey_vec)?
|
.get(&statekey_vec)?
|
||||||
.map(|shortstatekey| {
|
.map(|shortstatekey| {
|
||||||
utils::u64_from_bytes(&shortstatekey)
|
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
|
||||||
if let Some(s) = short {
|
if let Some(s) = short {
|
||||||
self.statekeyshort_cache
|
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s);
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert((event_type.clone(), state_key.to_owned()), s);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(short)
|
Ok(short)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_create_shortstatekey(
|
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
|
||||||
&self,
|
if let Some(short) =
|
||||||
event_type: &StateEventType,
|
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||||
state_key: &str,
|
|
||||||
) -> Result<u64> {
|
|
||||||
if let Some(short) = self
|
|
||||||
.statekeyshort_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
|
||||||
{
|
{
|
||||||
return Ok(*short);
|
return Ok(*short);
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||||
statekey_vec.push(0xff);
|
statekey_vec.push(0xFF);
|
||||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||||
|
|
||||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||||
|
@ -92,29 +70,19 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
||||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||||
None => {
|
None => {
|
||||||
let shortstatekey = services().globals.next_count()?;
|
let shortstatekey = services().globals.next_count()?;
|
||||||
self.statekey_shortstatekey
|
self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||||
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||||
self.shortstatekey_statekey
|
|
||||||
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
|
||||||
shortstatekey
|
shortstatekey
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
self.statekeyshort_cache
|
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short);
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert((event_type.clone(), state_key.to_owned()), short);
|
|
||||||
|
|
||||||
Ok(short)
|
Ok(short)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||||
if let Some(id) = self
|
if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) {
|
||||||
.shorteventid_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.get_mut(&shorteventid)
|
|
||||||
{
|
|
||||||
return Ok(Arc::clone(id));
|
return Ok(Arc::clone(id));
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -123,26 +91,19 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
||||||
.get(&shorteventid.to_be_bytes())?
|
.get(&shorteventid.to_be_bytes())?
|
||||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||||
|
|
||||||
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
let event_id = EventId::parse_arc(
|
||||||
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
|
utils::string_from_bytes(&bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||||
|
|
||||||
self.shorteventid_cache
|
self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id));
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(shorteventid, Arc::clone(&event_id));
|
|
||||||
|
|
||||||
Ok(event_id)
|
Ok(event_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||||
if let Some(id) = self
|
if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) {
|
||||||
.shortstatekey_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.get_mut(&shortstatekey)
|
|
||||||
{
|
|
||||||
return Ok(id.clone());
|
return Ok(id.clone());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -151,28 +112,22 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
||||||
.get(&shortstatekey.to_be_bytes())?
|
.get(&shortstatekey.to_be_bytes())?
|
||||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||||
|
|
||||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
|
||||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||||
let statekey_bytes = parts
|
let statekey_bytes =
|
||||||
.next()
|
parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||||
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
|
||||||
|
|
||||||
let event_type =
|
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||||
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
|
||||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||||
})?);
|
})?);
|
||||||
|
|
||||||
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
|
let state_key = utils::string_from_bytes(statekey_bytes)
|
||||||
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
|
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
let result = (event_type, state_key);
|
let result = (event_type, state_key);
|
||||||
|
|
||||||
self.shortstatekey_cache
|
self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone());
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(shortstatekey, result.clone());
|
|
||||||
|
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
@ -187,33 +142,29 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
||||||
),
|
),
|
||||||
None => {
|
None => {
|
||||||
let shortstatehash = services().globals.next_count()?;
|
let shortstatehash = services().globals.next_count()?;
|
||||||
self.statehash_shortstatehash
|
self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||||
.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
|
||||||
(shortstatehash, false)
|
(shortstatehash, false)
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||||
self.roomid_shortroomid
|
self.roomid_shortroomid
|
||||||
.get(room_id.as_bytes())?
|
.get(room_id.as_bytes())?
|
||||||
.map(|bytes| {
|
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
|
||||||
utils::u64_from_bytes(&bytes)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||||
Some(short) => utils::u64_from_bytes(&short)
|
Some(short) => {
|
||||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
|
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
|
||||||
|
},
|
||||||
None => {
|
None => {
|
||||||
let short = services().globals.next_count()?;
|
let short = services().globals.next_count()?;
|
||||||
self.roomid_shortroomid
|
self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||||
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
|
||||||
short
|
short
|
||||||
}
|
},
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,16 +1,13 @@
|
||||||
use ruma::{EventId, OwnedEventId, RoomId};
|
use std::{collections::HashSet, sync::Arc};
|
||||||
use std::collections::HashSet;
|
|
||||||
|
|
||||||
use std::sync::Arc;
|
use ruma::{EventId, OwnedEventId, RoomId};
|
||||||
use tokio::sync::MutexGuard;
|
use tokio::sync::MutexGuard;
|
||||||
|
|
||||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||||
|
|
||||||
impl service::rooms::state::Data for KeyValueDatabase {
|
impl service::rooms::state::Data for KeyValueDatabase {
|
||||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||||
self.roomid_shortstatehash
|
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||||
.get(room_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |bytes| {
|
|
||||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||||
})?))
|
})?))
|
||||||
|
@ -23,27 +20,26 @@ impl service::rooms::state::Data for KeyValueDatabase {
|
||||||
new_shortstatehash: u64,
|
new_shortstatehash: u64,
|
||||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.roomid_shortstatehash
|
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||||
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||||
self.shorteventid_shortstatehash
|
self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||||
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
self.roomid_pduleaves
|
self.roomid_pduleaves
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
.map(|(_, bytes)| {
|
.map(|(_, bytes)| {
|
||||||
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
EventId::parse_arc(
|
||||||
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
|
utils::string_from_bytes(&bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||||
})
|
})
|
||||||
.collect()
|
.collect()
|
||||||
|
@ -56,7 +52,7 @@ impl service::rooms::state::Data for KeyValueDatabase {
|
||||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||||
self.roomid_pduleaves.remove(&key)?;
|
self.roomid_pduleaves.remove(&key)?;
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
use std::{collections::HashMap, sync::Arc};
|
use std::{collections::HashMap, sync::Arc};
|
||||||
|
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
|
||||||
use async_trait::async_trait;
|
use async_trait::async_trait;
|
||||||
use ruma::{events::StateEventType, EventId, RoomId};
|
use ruma::{events::StateEventType, EventId, RoomId};
|
||||||
|
|
||||||
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||||
|
|
||||||
#[async_trait]
|
#[async_trait]
|
||||||
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||||
|
@ -17,10 +18,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
let mut result = HashMap::new();
|
let mut result = HashMap::new();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
for compressed in full_state.iter() {
|
for compressed in full_state.iter() {
|
||||||
let parsed = services()
|
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||||
.rooms
|
|
||||||
.state_compressor
|
|
||||||
.parse_compressed_state_event(compressed)?;
|
|
||||||
result.insert(parsed.0, parsed.1);
|
result.insert(parsed.0, parsed.1);
|
||||||
|
|
||||||
i += 1;
|
i += 1;
|
||||||
|
@ -31,10 +29,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn state_full(
|
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||||
&self,
|
|
||||||
shortstatehash: u64,
|
|
||||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
|
||||||
let full_state = services()
|
let full_state = services()
|
||||||
.rooms
|
.rooms
|
||||||
.state_compressor
|
.state_compressor
|
||||||
|
@ -46,10 +41,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
let mut result = HashMap::new();
|
let mut result = HashMap::new();
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
for compressed in full_state.iter() {
|
for compressed in full_state.iter() {
|
||||||
let (_, eventid) = services()
|
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||||
.rooms
|
|
||||||
.state_compressor
|
|
||||||
.parse_compressed_state_event(compressed)?;
|
|
||||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||||
result.insert(
|
result.insert(
|
||||||
(
|
(
|
||||||
|
@ -72,18 +64,12 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
Ok(result)
|
Ok(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||||
|
/// `state_key`).
|
||||||
fn state_get_id(
|
fn state_get_id(
|
||||||
&self,
|
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||||
shortstatehash: u64,
|
|
||||||
event_type: &StateEventType,
|
|
||||||
state_key: &str,
|
|
||||||
) -> Result<Option<Arc<EventId>>> {
|
) -> Result<Option<Arc<EventId>>> {
|
||||||
let shortstatekey = match services()
|
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_shortstatekey(event_type, state_key)?
|
|
||||||
{
|
|
||||||
Some(s) => s,
|
Some(s) => s,
|
||||||
None => return Ok(None),
|
None => return Ok(None),
|
||||||
};
|
};
|
||||||
|
@ -94,90 +80,62 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||||
.pop()
|
.pop()
|
||||||
.expect("there is always one layer")
|
.expect("there is always one layer")
|
||||||
.1;
|
.1;
|
||||||
Ok(full_state
|
Ok(
|
||||||
.iter()
|
full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| {
|
||||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
|
services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id)
|
||||||
.and_then(|compressed| {
|
}),
|
||||||
services()
|
)
|
||||||
.rooms
|
|
||||||
.state_compressor
|
|
||||||
.parse_compressed_state_event(compressed)
|
|
||||||
.ok()
|
|
||||||
.map(|(_, id)| id)
|
|
||||||
}))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||||
|
/// `state_key`).
|
||||||
fn state_get(
|
fn state_get(
|
||||||
&self,
|
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||||
shortstatehash: u64,
|
|
||||||
event_type: &StateEventType,
|
|
||||||
state_key: &str,
|
|
||||||
) -> Result<Option<Arc<PduEvent>>> {
|
) -> Result<Option<Arc<PduEvent>>> {
|
||||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||||
.map_or(Ok(None), |event_id| {
|
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
|
||||||
services().rooms.timeline.get_pdu(&event_id)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the state hash for this pdu.
|
/// Returns the state hash for this pdu.
|
||||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||||
self.eventid_shorteventid
|
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| {
|
||||||
.get(event_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |shorteventid| {
|
|
||||||
self.shorteventid_shortstatehash
|
self.shorteventid_shortstatehash
|
||||||
.get(&shorteventid)?
|
.get(&shorteventid)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database(
|
.map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash"))
|
||||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
|
|
||||||
)
|
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the full room state.
|
/// Returns the full room state.
|
||||||
async fn room_state_full(
|
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||||
&self,
|
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
|
||||||
if let Some(current_shortstatehash) =
|
|
||||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
|
||||||
{
|
|
||||||
self.state_full(current_shortstatehash).await
|
self.state_full(current_shortstatehash).await
|
||||||
} else {
|
} else {
|
||||||
Ok(HashMap::new())
|
Ok(HashMap::new())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||||
|
/// `state_key`).
|
||||||
fn room_state_get_id(
|
fn room_state_get_id(
|
||||||
&self,
|
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||||
room_id: &RoomId,
|
|
||||||
event_type: &StateEventType,
|
|
||||||
state_key: &str,
|
|
||||||
) -> Result<Option<Arc<EventId>>> {
|
) -> Result<Option<Arc<EventId>>> {
|
||||||
if let Some(current_shortstatehash) =
|
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
|
||||||
{
|
|
||||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||||
|
/// `state_key`).
|
||||||
fn room_state_get(
|
fn room_state_get(
|
||||||
&self,
|
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||||
room_id: &RoomId,
|
|
||||||
event_type: &StateEventType,
|
|
||||||
state_key: &str,
|
|
||||||
) -> Result<Option<Arc<PduEvent>>> {
|
) -> Result<Option<Arc<PduEvent>>> {
|
||||||
if let Some(current_shortstatehash) =
|
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
|
||||||
{
|
|
||||||
self.state_get(current_shortstatehash, event_type, state_key)
|
self.state_get(current_shortstatehash, event_type, state_key)
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
|
|
|
@ -10,27 +10,25 @@ use ruma::{
|
||||||
|
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||||
|
|
||||||
type StrippedStateEventIter<'a> =
|
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
|
||||||
Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
|
|
||||||
|
|
||||||
type AnySyncStateEventIter<'a> =
|
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
|
||||||
Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
|
|
||||||
|
|
||||||
impl service::rooms::state_cache::Data for KeyValueDatabase {
|
impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
self.roomuseroncejoinedids.insert(&userroom_id, &[])
|
self.roomuseroncejoinedids.insert(&userroom_id, &[])
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
roomuser_id.push(0xff);
|
roomuser_id.push(0xFF);
|
||||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_joined.insert(&userroom_id, &[])?;
|
self.userroomid_joined.insert(&userroom_id, &[])?;
|
||||||
|
@ -44,28 +42,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_as_invited(
|
fn mark_as_invited(
|
||||||
&self,
|
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
roomuser_id.push(0xff);
|
roomuser_id.push(0xFF);
|
||||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_invitestate.insert(
|
self.userroomid_invitestate.insert(
|
||||||
&userroom_id,
|
&userroom_id,
|
||||||
&serde_json::to_vec(&last_state.unwrap_or_default())
|
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
|
||||||
.expect("state to bytes always works"),
|
|
||||||
)?;
|
|
||||||
self.roomuserid_invitecount.insert(
|
|
||||||
&roomuser_id,
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
)?;
|
||||||
|
self.roomuserid_invitecount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
self.userroomid_joined.remove(&userroom_id)?;
|
self.userroomid_joined.remove(&userroom_id)?;
|
||||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||||
|
@ -76,21 +67,18 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
|
|
||||||
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
roomuser_id.push(0xff);
|
roomuser_id.push(0xFF);
|
||||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_leftstate.insert(
|
self.userroomid_leftstate.insert(
|
||||||
&userroom_id,
|
&userroom_id,
|
||||||
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
|
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
|
||||||
)?; // TODO
|
)?; // TODO
|
||||||
self.roomuserid_leftcount.insert(
|
self.roomuserid_leftcount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
&roomuser_id,
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
self.userroomid_joined.remove(&userroom_id)?;
|
self.userroomid_joined.remove(&userroom_id)?;
|
||||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||||
self.userroomid_invitestate.remove(&userroom_id)?;
|
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||||
|
@ -105,10 +93,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
let mut joined_servers = HashSet::new();
|
let mut joined_servers = HashSet::new();
|
||||||
let mut real_users = HashSet::new();
|
let mut real_users = HashSet::new();
|
||||||
|
|
||||||
for joined in self
|
for joined in self.room_members(room_id).filter_map(std::result::Result::ok) {
|
||||||
.room_members(room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
joined_servers.insert(joined.server_name().to_owned());
|
joined_servers.insert(joined.server_name().to_owned());
|
||||||
if joined.server_name() == services().globals.server_name()
|
if joined.server_name() == services().globals.server_name()
|
||||||
&& !services().users.is_deactivated(&joined).unwrap_or(true)
|
&& !services().users.is_deactivated(&joined).unwrap_or(true)
|
||||||
|
@ -118,36 +103,25 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
joinedcount += 1;
|
joinedcount += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
for _invited in self
|
for _invited in self.room_members_invited(room_id).filter_map(std::result::Result::ok) {
|
||||||
.room_members_invited(room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
invitedcount += 1;
|
invitedcount += 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.roomid_joinedcount
|
self.roomid_joinedcount.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
|
||||||
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
|
|
||||||
|
|
||||||
self.roomid_invitedcount
|
self.roomid_invitedcount.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
|
||||||
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
|
|
||||||
|
|
||||||
self.our_real_users_cache
|
self.our_real_users_cache.write().unwrap().insert(room_id.to_owned(), Arc::new(real_users));
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.insert(room_id.to_owned(), Arc::new(real_users));
|
|
||||||
|
|
||||||
for old_joined_server in self
|
for old_joined_server in self.room_servers(room_id).filter_map(std::result::Result::ok) {
|
||||||
.room_servers(room_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
if !joined_servers.remove(&old_joined_server) {
|
if !joined_servers.remove(&old_joined_server) {
|
||||||
// Server not in room anymore
|
// Server not in room anymore
|
||||||
let mut roomserver_id = room_id.as_bytes().to_vec();
|
let mut roomserver_id = room_id.as_bytes().to_vec();
|
||||||
roomserver_id.push(0xff);
|
roomserver_id.push(0xFF);
|
||||||
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
|
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
|
||||||
|
|
||||||
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
|
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
|
||||||
serverroom_id.push(0xff);
|
serverroom_id.push(0xFF);
|
||||||
serverroom_id.extend_from_slice(room_id.as_bytes());
|
serverroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.roomserverids.remove(&roomserver_id)?;
|
self.roomserverids.remove(&roomserver_id)?;
|
||||||
|
@ -158,70 +132,44 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
// Now only new servers are in joined_servers anymore
|
// Now only new servers are in joined_servers anymore
|
||||||
for server in joined_servers {
|
for server in joined_servers {
|
||||||
let mut roomserver_id = room_id.as_bytes().to_vec();
|
let mut roomserver_id = room_id.as_bytes().to_vec();
|
||||||
roomserver_id.push(0xff);
|
roomserver_id.push(0xFF);
|
||||||
roomserver_id.extend_from_slice(server.as_bytes());
|
roomserver_id.extend_from_slice(server.as_bytes());
|
||||||
|
|
||||||
let mut serverroom_id = server.as_bytes().to_vec();
|
let mut serverroom_id = server.as_bytes().to_vec();
|
||||||
serverroom_id.push(0xff);
|
serverroom_id.push(0xFF);
|
||||||
serverroom_id.extend_from_slice(room_id.as_bytes());
|
serverroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.roomserverids.insert(&roomserver_id, &[])?;
|
self.roomserverids.insert(&roomserver_id, &[])?;
|
||||||
self.serverroomids.insert(&serverroom_id, &[])?;
|
self.serverroomids.insert(&serverroom_id, &[])?;
|
||||||
}
|
}
|
||||||
|
|
||||||
self.appservice_in_room_cache
|
self.appservice_in_room_cache.write().unwrap().remove(room_id);
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.remove(room_id);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, room_id))]
|
#[tracing::instrument(skip(self, room_id))]
|
||||||
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
|
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
|
||||||
let maybe = self
|
let maybe = self.our_real_users_cache.read().unwrap().get(room_id).cloned();
|
||||||
.our_real_users_cache
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.get(room_id)
|
|
||||||
.cloned();
|
|
||||||
if let Some(users) = maybe {
|
if let Some(users) = maybe {
|
||||||
Ok(users)
|
Ok(users)
|
||||||
} else {
|
} else {
|
||||||
self.update_joined_count(room_id)?;
|
self.update_joined_count(room_id)?;
|
||||||
Ok(Arc::clone(
|
Ok(Arc::clone(self.our_real_users_cache.read().unwrap().get(room_id).unwrap()))
|
||||||
self.our_real_users_cache
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.get(room_id)
|
|
||||||
.unwrap(),
|
|
||||||
))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, room_id, appservice))]
|
#[tracing::instrument(skip(self, room_id, appservice))]
|
||||||
fn appservice_in_room(
|
fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result<bool> {
|
||||||
&self,
|
let maybe =
|
||||||
room_id: &RoomId,
|
self.appservice_in_room_cache.read().unwrap().get(room_id).and_then(|map| map.get(&appservice.0)).copied();
|
||||||
appservice: &(String, Registration),
|
|
||||||
) -> Result<bool> {
|
|
||||||
let maybe = self
|
|
||||||
.appservice_in_room_cache
|
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.get(room_id)
|
|
||||||
.and_then(|map| map.get(&appservice.0))
|
|
||||||
.copied();
|
|
||||||
|
|
||||||
if let Some(b) = maybe {
|
if let Some(b) = maybe {
|
||||||
Ok(b)
|
Ok(b)
|
||||||
} else {
|
} else {
|
||||||
let namespaces = &appservice.1.namespaces;
|
let namespaces = &appservice.1.namespaces;
|
||||||
let users = namespaces
|
let users =
|
||||||
.users
|
namespaces.users.iter().filter_map(|users| Regex::new(users.regex.as_str()).ok()).collect::<Vec<_>>();
|
||||||
.iter()
|
|
||||||
.filter_map(|users| Regex::new(users.regex.as_str()).ok())
|
|
||||||
.collect::<Vec<_>>();
|
|
||||||
|
|
||||||
let bridge_user_id = UserId::parse_with_server_name(
|
let bridge_user_id = UserId::parse_with_server_name(
|
||||||
appservice.1.sender_localpart.as_str(),
|
appservice.1.sender_localpart.as_str(),
|
||||||
|
@ -229,13 +177,10 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
)
|
)
|
||||||
.ok();
|
.ok();
|
||||||
|
|
||||||
let in_room = bridge_user_id
|
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|
||||||
.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|
|| self
|
||||||
|| self.room_members(room_id).any(|userid| {
|
.room_members(room_id)
|
||||||
userid.map_or(false, |userid| {
|
.any(|userid| userid.map_or(false, |userid| users.iter().any(|r| r.is_match(userid.as_str()))));
|
||||||
users.iter().any(|r| r.is_match(userid.as_str()))
|
|
||||||
})
|
|
||||||
});
|
|
||||||
|
|
||||||
self.appservice_in_room_cache
|
self.appservice_in_room_cache
|
||||||
.write()
|
.write()
|
||||||
|
@ -252,11 +197,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
|
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
roomuser_id.push(0xff);
|
roomuser_id.push(0xFF);
|
||||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||||
|
@ -267,23 +212,14 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
|
|
||||||
/// Returns an iterator of all servers participating in this room.
|
/// Returns an iterator of all servers participating in this room.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn room_servers<'a>(
|
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
|
||||||
&'a self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
|
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
|
||||||
ServerName::parse(
|
ServerName::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("Server name in roomserverids is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
|
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
|
||||||
}))
|
}))
|
||||||
|
@ -292,28 +228,22 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
|
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
|
||||||
let mut key = server.as_bytes().to_vec();
|
let mut key = server.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.serverroomids.get(&key).map(|o| o.is_some())
|
self.serverroomids.get(&key).map(|o| o.is_some())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator of all rooms a server participates in (as far as we know).
|
/// Returns an iterator of all rooms a server participates in (as far as we
|
||||||
|
/// know).
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn server_rooms<'a>(
|
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||||
&'a self,
|
|
||||||
server: &ServerName,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
|
||||||
let mut prefix = server.as_bytes().to_vec();
|
let mut prefix = server.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
|
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
|
||||||
RoomId::parse(
|
RoomId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
|
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
|
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
|
||||||
|
@ -322,23 +252,14 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
|
|
||||||
/// Returns an iterator over all joined members of a room.
|
/// Returns an iterator over all joined members of a room.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn room_members<'a>(
|
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||||
&'a self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
|
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
|
||||||
UserId::parse(
|
UserId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
|
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
|
||||||
}))
|
}))
|
||||||
|
@ -348,10 +269,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||||
self.roomid_joinedcount
|
self.roomid_joinedcount
|
||||||
.get(room_id.as_bytes())?
|
.get(room_id.as_bytes())?
|
||||||
.map(|b| {
|
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
|
||||||
utils::u64_from_bytes(&b)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -359,167 +277,101 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||||
self.roomid_invitedcount
|
self.roomid_invitedcount
|
||||||
.get(room_id.as_bytes())?
|
.get(room_id.as_bytes())?
|
||||||
.map(|b| {
|
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
|
||||||
utils::u64_from_bytes(&b)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all User IDs who ever joined a room.
|
/// Returns an iterator over all User IDs who ever joined a room.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn room_useroncejoined<'a>(
|
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||||
&'a self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(
|
Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(|(key, _)| {
|
||||||
self.roomuseroncejoinedids
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(key, _)| {
|
|
||||||
UserId::parse(
|
UserId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database(
|
|
||||||
"User ID in room_useroncejoined is invalid unicode.",
|
|
||||||
)
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
|
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all invited members of a room.
|
/// Returns an iterator over all invited members of a room.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn room_members_invited<'a>(
|
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||||
&'a self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
|
||||||
let mut prefix = room_id.as_bytes().to_vec();
|
let mut prefix = room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(
|
Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(|(key, _)| {
|
||||||
self.roomuserid_invitecount
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(key, _)| {
|
|
||||||
UserId::parse(
|
UserId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
|
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.roomuserid_invitecount
|
self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| {
|
||||||
.get(&key)?
|
Ok(Some(
|
||||||
.map_or(Ok(None), |bytes| {
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
|
||||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
))
|
||||||
Error::bad_database("Invalid invitecount in db.")
|
|
||||||
})?))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.roomuserid_leftcount
|
self.roomuserid_leftcount
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.map(|bytes| {
|
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
|
||||||
utils::u64_from_bytes(&bytes)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid leftcount in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all rooms this user joined.
|
/// Returns an iterator over all rooms this user joined.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn rooms_joined<'a>(
|
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||||
&'a self,
|
Box::new(self.userroomid_joined.scan_prefix(user_id.as_bytes().to_vec()).map(|(key, _)| {
|
||||||
user_id: &UserId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
|
||||||
Box::new(
|
|
||||||
self.userroomid_joined
|
|
||||||
.scan_prefix(user_id.as_bytes().to_vec())
|
|
||||||
.map(|(key, _)| {
|
|
||||||
RoomId::parse(
|
RoomId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
|
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all rooms a user was invited to.
|
/// Returns an iterator over all rooms a user was invited to.
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
|
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(
|
Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(|(key, state)| {
|
||||||
self.userroomid_invitestate
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(key, state)| {
|
|
||||||
let room_id = RoomId::parse(
|
let room_id = RoomId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
|
||||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
let state = serde_json::from_slice(&state)
|
||||||
Error::bad_database("Invalid state in userroomid_invitestate.")
|
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok((room_id, state))
|
Ok((room_id, state))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn invite_state(
|
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_invitestate
|
self.userroomid_invitestate
|
||||||
|
@ -534,13 +386,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn left_state(
|
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(room_id.as_bytes());
|
key.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_leftstate
|
self.userroomid_leftstate
|
||||||
|
@ -558,39 +406,26 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
|
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
Box::new(
|
Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(|(key, state)| {
|
||||||
self.userroomid_leftstate
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(key, state)| {
|
|
||||||
let room_id = RoomId::parse(
|
let room_id = RoomId::parse(
|
||||||
utils::string_from_bytes(
|
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||||
key.rsplit(|&b| b == 0xff)
|
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
|
||||||
.next()
|
|
||||||
.expect("rsplit always returns an element"),
|
|
||||||
)
|
)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
|
||||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
|
||||||
})?,
|
|
||||||
)
|
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
let state = serde_json::from_slice(&state)
|
||||||
Error::bad_database("Invalid state in userroomid_leftstate.")
|
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok((room_id, state))
|
Ok((room_id, state))
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
|
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
|
||||||
|
@ -599,7 +434,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
|
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
|
||||||
|
@ -608,7 +443,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
|
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
|
||||||
|
@ -617,7 +452,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
||||||
|
|
|
@ -12,9 +12,12 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
|
||||||
.shortstatehash_statediff
|
.shortstatehash_statediff
|
||||||
.get(&shortstatehash.to_be_bytes())?
|
.get(&shortstatehash.to_be_bytes())?
|
||||||
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
|
||||||
let parent =
|
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
||||||
utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
|
let parent = if parent != 0 {
|
||||||
let parent = if parent != 0 { Some(parent) } else { None };
|
Some(parent)
|
||||||
|
} else {
|
||||||
|
None
|
||||||
|
};
|
||||||
|
|
||||||
let mut add_mode = true;
|
let mut add_mode = true;
|
||||||
let mut added = HashSet::new();
|
let mut added = HashSet::new();
|
||||||
|
@ -55,7 +58,6 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
self.shortstatehash_statediff
|
self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value)
|
||||||
.insert(&shortstatehash.to_be_bytes(), &value)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -8,51 +8,34 @@ type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEve
|
||||||
|
|
||||||
impl service::rooms::threads::Data for KeyValueDatabase {
|
impl service::rooms::threads::Data for KeyValueDatabase {
|
||||||
fn threads_until<'a>(
|
fn threads_until<'a>(
|
||||||
&'a self,
|
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
|
||||||
user_id: &'a UserId,
|
|
||||||
room_id: &'a RoomId,
|
|
||||||
until: u64,
|
|
||||||
_include: &'a IncludeThreads,
|
|
||||||
) -> PduEventIterResult<'a> {
|
) -> PduEventIterResult<'a> {
|
||||||
let prefix = services()
|
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_shortroomid(room_id)?
|
|
||||||
.expect("room exists")
|
|
||||||
.to_be_bytes()
|
|
||||||
.to_vec();
|
|
||||||
|
|
||||||
let mut current = prefix.clone();
|
let mut current = prefix.clone();
|
||||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||||
|
|
||||||
Ok(Box::new(
|
Ok(Box::new(
|
||||||
self.threadid_userids
|
self.threadid_userids.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||||
.iter_from(¤t, true)
|
move |(pduid, _users)| {
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
|
||||||
.map(move |(pduid, _users)| {
|
|
||||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||||
let mut pdu = services()
|
let mut pdu = services()
|
||||||
.rooms
|
.rooms
|
||||||
.timeline
|
.timeline
|
||||||
.get_pdu_from_id(&pduid)?
|
.get_pdu_from_id(&pduid)?
|
||||||
.ok_or_else(|| {
|
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
|
||||||
Error::bad_database("Invalid pduid reference in threadid_userids")
|
|
||||||
})?;
|
|
||||||
if pdu.sender != user_id {
|
if pdu.sender != user_id {
|
||||||
pdu.remove_transaction_id()?;
|
pdu.remove_transaction_id()?;
|
||||||
}
|
}
|
||||||
Ok((count, pdu))
|
Ok((count, pdu))
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||||
let users = participants
|
let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]);
|
||||||
.iter()
|
|
||||||
.map(|user| user.as_bytes())
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
.join(&[0xff][..]);
|
|
||||||
|
|
||||||
self.threadid_userids.insert(root_id, &users)?;
|
self.threadid_userids.insert(root_id, &users)?;
|
||||||
|
|
||||||
|
@ -63,11 +46,12 @@ impl service::rooms::threads::Data for KeyValueDatabase {
|
||||||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||||
Ok(Some(
|
Ok(Some(
|
||||||
users
|
users
|
||||||
.split(|b| *b == 0xff)
|
.split(|b| *b == 0xFF)
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| {
|
UserId::parse(
|
||||||
Error::bad_database("Invalid UserId bytes in threadid_userids.")
|
utils::string_from_bytes(bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||||
})
|
})
|
||||||
.filter_map(std::result::Result::ok)
|
.filter_map(std::result::Result::ok)
|
||||||
|
|
|
@ -1,48 +1,34 @@
|
||||||
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
|
||||||
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId,
|
use service::rooms::timeline::PduCount;
|
||||||
};
|
|
||||||
use tracing::error;
|
use tracing::error;
|
||||||
|
|
||||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||||
|
|
||||||
use service::rooms::timeline::PduCount;
|
|
||||||
|
|
||||||
impl service::rooms::timeline::Data for KeyValueDatabase {
|
impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||||
match self
|
match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) {
|
||||||
.lasttimelinecount_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.entry(room_id.to_owned())
|
|
||||||
{
|
|
||||||
hash_map::Entry::Vacant(v) => {
|
hash_map::Entry::Vacant(v) => {
|
||||||
if let Some(last_count) = self
|
if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| {
|
||||||
.pdus_until(sender_user, room_id, PduCount::max())?
|
|
||||||
.find_map(|r| {
|
|
||||||
// Filter out buggy events
|
// Filter out buggy events
|
||||||
if r.is_err() {
|
if r.is_err() {
|
||||||
error!("Bad pdu in pdus_since: {:?}", r);
|
error!("Bad pdu in pdus_since: {:?}", r);
|
||||||
}
|
}
|
||||||
r.ok()
|
r.ok()
|
||||||
})
|
}) {
|
||||||
{
|
|
||||||
Ok(*v.insert(last_count.0))
|
Ok(*v.insert(last_count.0))
|
||||||
} else {
|
} else {
|
||||||
Ok(PduCount::Normal(0))
|
Ok(PduCount::Normal(0))
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the `count` of this pdu's id.
|
/// Returns the `count` of this pdu's id.
|
||||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||||
self.eventid_pduid
|
self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose()
|
||||||
.get(event_id.as_bytes())?
|
|
||||||
.map(|pdu_id| pdu_count(&pdu_id))
|
|
||||||
.transpose()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the json of a pdu.
|
/// Returns the json of a pdu.
|
||||||
|
@ -51,10 +37,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
|| {
|
|| {
|
||||||
self.eventid_outlierpdu
|
self.eventid_outlierpdu
|
||||||
.get(event_id.as_bytes())?
|
.get(event_id.as_bytes())?
|
||||||
.map(|pdu| {
|
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||||
serde_json::from_slice(&pdu)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
},
|
},
|
||||||
|x| Ok(Some(x)),
|
|x| Ok(Some(x)),
|
||||||
|
@ -66,35 +49,25 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
self.eventid_pduid
|
self.eventid_pduid
|
||||||
.get(event_id.as_bytes())?
|
.get(event_id.as_bytes())?
|
||||||
.map(|pduid| {
|
.map(|pduid| {
|
||||||
self.pduid_pdu
|
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||||
.get(&pduid)?
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.map(|pdu| {
|
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the pdu's id.
|
/// Returns the pdu's id.
|
||||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
|
||||||
self.eventid_pduid.get(event_id.as_bytes())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the pdu.
|
/// Returns the pdu.
|
||||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||||
self.eventid_pduid
|
self.eventid_pduid
|
||||||
.get(event_id.as_bytes())?
|
.get(event_id.as_bytes())?
|
||||||
.map(|pduid| {
|
.map(|pduid| {
|
||||||
self.pduid_pdu
|
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||||
.get(&pduid)?
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.map(|pdu| {
|
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -112,20 +85,14 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
|| {
|
|| {
|
||||||
self.eventid_outlierpdu
|
self.eventid_outlierpdu
|
||||||
.get(event_id.as_bytes())?
|
.get(event_id.as_bytes())?
|
||||||
.map(|pdu| {
|
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||||
serde_json::from_slice(&pdu)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
|
||||||
})
|
|
||||||
.transpose()
|
.transpose()
|
||||||
},
|
},
|
||||||
|x| Ok(Some(x)),
|
|x| Ok(Some(x)),
|
||||||
)?
|
)?
|
||||||
.map(Arc::new)
|
.map(Arc::new)
|
||||||
{
|
{
|
||||||
self.pdu_cache
|
self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(event_id.to_owned(), Arc::clone(&pdu));
|
|
||||||
Ok(Some(pdu))
|
Ok(Some(pdu))
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
|
@ -138,8 +105,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||||
Ok(Some(
|
Ok(Some(
|
||||||
serde_json::from_slice(&pdu)
|
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -148,28 +114,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||||
Ok(Some(
|
Ok(Some(
|
||||||
serde_json::from_slice(&pdu)
|
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
|
||||||
))
|
))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn append_pdu(
|
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
|
||||||
&self,
|
|
||||||
pdu_id: &[u8],
|
|
||||||
pdu: &PduEvent,
|
|
||||||
json: &CanonicalJsonObject,
|
|
||||||
count: u64,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.pduid_pdu.insert(
|
self.pduid_pdu.insert(
|
||||||
pdu_id,
|
pdu_id,
|
||||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
self.lasttimelinecount_cache
|
self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.insert(pdu.room_id.clone(), PduCount::Normal(count));
|
|
||||||
|
|
||||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||||
|
@ -177,12 +133,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn prepend_backfill_pdu(
|
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
|
||||||
&self,
|
|
||||||
pdu_id: &[u8],
|
|
||||||
event_id: &EventId,
|
|
||||||
json: &CanonicalJsonObject,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.pduid_pdu.insert(
|
self.pduid_pdu.insert(
|
||||||
pdu_id,
|
pdu_id,
|
||||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||||
|
@ -195,49 +146,34 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Removes a pdu and creates a new one with the same id.
|
/// Removes a pdu and creates a new one with the same id.
|
||||||
fn replace_pdu(
|
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
|
||||||
&self,
|
|
||||||
pdu_id: &[u8],
|
|
||||||
pdu_json: &CanonicalJsonObject,
|
|
||||||
pdu: &PduEvent,
|
|
||||||
) -> Result<()> {
|
|
||||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||||
self.pduid_pdu.insert(
|
self.pduid_pdu.insert(
|
||||||
pdu_id,
|
pdu_id,
|
||||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||||
)?;
|
)?;
|
||||||
} else {
|
} else {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
|
||||||
ErrorKind::NotFound,
|
|
||||||
"PDU does not exist.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
self.pdu_cache
|
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.remove(&(*pdu.event_id).to_owned());
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all events and their tokens in a room that happened before the
|
/// Returns an iterator over all events and their tokens in a room that
|
||||||
/// event with id `until` in reverse-chronological order.
|
/// happened before the event with id `until` in reverse-chronological
|
||||||
|
/// order.
|
||||||
fn pdus_until<'a>(
|
fn pdus_until<'a>(
|
||||||
&'a self,
|
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
until: PduCount,
|
|
||||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||||
|
|
||||||
let user_id = user_id.to_owned();
|
let user_id = user_id.to_owned();
|
||||||
|
|
||||||
Ok(Box::new(
|
Ok(Box::new(
|
||||||
self.pduid_pdu
|
self.pduid_pdu.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||||
.iter_from(¤t, true)
|
move |(pdu_id, v)| {
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
|
||||||
.map(move |(pdu_id, v)| {
|
|
||||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||||
if pdu.sender != user_id {
|
if pdu.sender != user_id {
|
||||||
|
@ -246,25 +182,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
pdu.add_age()?;
|
pdu.add_age()?;
|
||||||
let count = pdu_count(&pdu_id)?;
|
let count = pdu_count(&pdu_id)?;
|
||||||
Ok((count, pdu))
|
Ok((count, pdu))
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn pdus_after<'a>(
|
fn pdus_after<'a>(
|
||||||
&'a self,
|
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
from: PduCount,
|
|
||||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||||
|
|
||||||
let user_id = user_id.to_owned();
|
let user_id = user_id.to_owned();
|
||||||
|
|
||||||
Ok(Box::new(
|
Ok(Box::new(
|
||||||
self.pduid_pdu
|
self.pduid_pdu.iter_from(¤t, false).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||||
.iter_from(¤t, false)
|
move |(pdu_id, v)| {
|
||||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
|
||||||
.map(move |(pdu_id, v)| {
|
|
||||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||||
if pdu.sender != user_id {
|
if pdu.sender != user_id {
|
||||||
|
@ -273,35 +205,31 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
pdu.add_age()?;
|
pdu.add_age()?;
|
||||||
let count = pdu_count(&pdu_id)?;
|
let count = pdu_count(&pdu_id)?;
|
||||||
Ok((count, pdu))
|
Ok((count, pdu))
|
||||||
}),
|
},
|
||||||
|
),
|
||||||
))
|
))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn increment_notification_counts(
|
fn increment_notification_counts(
|
||||||
&self,
|
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
|
||||||
room_id: &RoomId,
|
|
||||||
notifies: Vec<OwnedUserId>,
|
|
||||||
highlights: Vec<OwnedUserId>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut notifies_batch = Vec::new();
|
let mut notifies_batch = Vec::new();
|
||||||
let mut highlights_batch = Vec::new();
|
let mut highlights_batch = Vec::new();
|
||||||
for user in notifies {
|
for user in notifies {
|
||||||
let mut userroom_id = user.as_bytes().to_vec();
|
let mut userroom_id = user.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
notifies_batch.push(userroom_id);
|
notifies_batch.push(userroom_id);
|
||||||
}
|
}
|
||||||
for user in highlights {
|
for user in highlights {
|
||||||
let mut userroom_id = user.as_bytes().to_vec();
|
let mut userroom_id = user.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
highlights_batch.push(userroom_id);
|
highlights_batch.push(userroom_id);
|
||||||
}
|
}
|
||||||
|
|
||||||
self.userroomid_notificationcount
|
self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?;
|
||||||
.increment_batch(&mut notifies_batch.into_iter())?;
|
self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?;
|
||||||
self.userroomid_highlightcount
|
|
||||||
.increment_batch(&mut highlights_batch.into_iter())?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -310,9 +238,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||||
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
||||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||||
let second_last_u64 = utils::u64_from_bytes(
|
let second_last_u64 =
|
||||||
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
|
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
|
||||||
);
|
|
||||||
|
|
||||||
if matches!(second_last_u64, Ok(0)) {
|
if matches!(second_last_u64, Ok(0)) {
|
||||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||||
|
@ -321,12 +248,7 @@ fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
fn count_to_id(
|
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||||
room_id: &RoomId,
|
|
||||||
count: PduCount,
|
|
||||||
offset: u64,
|
|
||||||
subtract: bool,
|
|
||||||
) -> Result<(Vec<u8>, Vec<u8>)> {
|
|
||||||
let prefix = services()
|
let prefix = services()
|
||||||
.rooms
|
.rooms
|
||||||
.short
|
.short
|
||||||
|
@ -343,7 +265,7 @@ fn count_to_id(
|
||||||
} else {
|
} else {
|
||||||
x + offset
|
x + offset
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
PduCount::Backfilled(x) => {
|
PduCount::Backfilled(x) => {
|
||||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||||
let num = u64::MAX - x;
|
let num = u64::MAX - x;
|
||||||
|
@ -356,7 +278,7 @@ fn count_to_id(
|
||||||
} else {
|
} else {
|
||||||
num + offset
|
num + offset
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||||
|
|
||||||
|
|
|
@ -5,95 +5,73 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
||||||
impl service::rooms::user::Data for KeyValueDatabase {
|
impl service::rooms::user::Data for KeyValueDatabase {
|
||||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||||
roomuser_id.push(0xff);
|
roomuser_id.push(0xFF);
|
||||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_notificationcount
|
self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||||
self.userroomid_highlightcount
|
|
||||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
|
||||||
|
|
||||||
self.roomuserid_lastnotificationread.insert(
|
self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||||
&roomuser_id,
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_notificationcount
|
self.userroomid_notificationcount
|
||||||
.get(&userroom_id)?
|
.get(&userroom_id)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes)
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||||
.map_err(|_| Error::bad_database("Invalid notification count in db."))
|
|
||||||
})
|
})
|
||||||
.unwrap_or(Ok(0))
|
.unwrap_or(Ok(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||||
userroom_id.push(0xff);
|
userroom_id.push(0xFF);
|
||||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||||
|
|
||||||
self.userroomid_highlightcount
|
self.userroomid_highlightcount
|
||||||
.get(&userroom_id)?
|
.get(&userroom_id)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes)
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||||
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
|
||||||
})
|
})
|
||||||
.unwrap_or(Ok(0))
|
.unwrap_or(Ok(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id.as_bytes());
|
key.extend_from_slice(user_id.as_bytes());
|
||||||
|
|
||||||
Ok(self
|
Ok(self
|
||||||
.roomuserid_lastnotificationread
|
.roomuserid_lastnotificationread
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.unwrap_or(0))
|
.unwrap_or(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn associate_token_shortstatehash(
|
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
|
||||||
&self,
|
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||||
room_id: &RoomId,
|
|
||||||
token: u64,
|
|
||||||
shortstatehash: u64,
|
|
||||||
) -> Result<()> {
|
|
||||||
let shortroomid = services()
|
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_shortroomid(room_id)?
|
|
||||||
.expect("room exists");
|
|
||||||
|
|
||||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||||
key.extend_from_slice(&token.to_be_bytes());
|
key.extend_from_slice(&token.to_be_bytes());
|
||||||
|
|
||||||
self.roomsynctoken_shortstatehash
|
self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes())
|
||||||
.insert(&key, &shortstatehash.to_be_bytes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||||
let shortroomid = services()
|
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_shortroomid(room_id)?
|
|
||||||
.expect("room exists");
|
|
||||||
|
|
||||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||||
key.extend_from_slice(&token.to_be_bytes());
|
key.extend_from_slice(&token.to_be_bytes());
|
||||||
|
@ -101,20 +79,18 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
||||||
self.roomsynctoken_shortstatehash
|
self.roomsynctoken_shortstatehash
|
||||||
.get(&key)?
|
.get(&key)?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
|
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.transpose()
|
.transpose()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_shared_rooms<'a>(
|
fn get_shared_rooms<'a>(
|
||||||
&'a self,
|
&'a self, users: Vec<OwnedUserId>,
|
||||||
users: Vec<OwnedUserId>,
|
|
||||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||||
let iterators = users.into_iter().map(move |user_id| {
|
let iterators = users.into_iter().map(move |user_id| {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
self.userroomid_joined
|
self.userroomid_joined
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
|
@ -122,10 +98,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
||||||
let roomid_index = key
|
let roomid_index = key
|
||||||
.iter()
|
.iter()
|
||||||
.enumerate()
|
.enumerate()
|
||||||
.find(|(_, &b)| b == 0xff)
|
.find(|(_, &b)| b == 0xFF)
|
||||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||||
.0
|
.0 + 1; // +1 because the room id starts AFTER the separator
|
||||||
+ 1; // +1 because the room id starts AFTER the separator
|
|
||||||
|
|
||||||
let room_id = key[roomid_index..].to_vec();
|
let room_id = key[roomid_index..].to_vec();
|
||||||
|
|
||||||
|
@ -134,14 +109,14 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
||||||
.filter_map(std::result::Result::ok)
|
.filter_map(std::result::Result::ok)
|
||||||
});
|
});
|
||||||
|
|
||||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
// We use the default compare function because keys are sorted correctly (not
|
||||||
|
// reversed)
|
||||||
Ok(Box::new(
|
Ok(Box::new(
|
||||||
utils::common_elements(iterators, Ord::cmp)
|
utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| {
|
||||||
.expect("users is not empty")
|
RoomId::parse(
|
||||||
.map(|bytes| {
|
utils::string_from_bytes(&bytes)
|
||||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
|
||||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
|
)
|
||||||
})?)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||||
}),
|
}),
|
||||||
))
|
))
|
||||||
|
|
|
@ -21,8 +21,7 @@ impl service::sending::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn active_requests_for<'a>(
|
fn active_requests_for<'a>(
|
||||||
&'a self,
|
&'a self, outgoing_kind: &OutgoingKind,
|
||||||
outgoing_kind: &OutgoingKind,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||||
let prefix = outgoing_kind.get_prefix();
|
let prefix = outgoing_kind.get_prefix();
|
||||||
Box::new(
|
Box::new(
|
||||||
|
@ -32,9 +31,7 @@ impl service::sending::Data for KeyValueDatabase {
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
|
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
|
||||||
self.servercurrentevent_data.remove(&key)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||||
let prefix = outgoing_kind.get_prefix();
|
let prefix = outgoing_kind.get_prefix();
|
||||||
|
@ -58,10 +55,7 @@ impl service::sending::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn queue_requests(
|
fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result<Vec<Vec<u8>>> {
|
||||||
&self,
|
|
||||||
requests: &[(&OutgoingKind, SendingEventType)],
|
|
||||||
) -> Result<Vec<Vec<u8>>> {
|
|
||||||
let mut batch = Vec::new();
|
let mut batch = Vec::new();
|
||||||
let mut keys = Vec::new();
|
let mut keys = Vec::new();
|
||||||
for (outgoing_kind, event) in requests {
|
for (outgoing_kind, event) in requests {
|
||||||
|
@ -79,14 +73,12 @@ impl service::sending::Data for KeyValueDatabase {
|
||||||
batch.push((key.clone(), value.to_owned()));
|
batch.push((key.clone(), value.to_owned()));
|
||||||
keys.push(key);
|
keys.push(key);
|
||||||
}
|
}
|
||||||
self.servernameevent_data
|
self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
|
||||||
.insert_batch(&mut batch.into_iter())?;
|
|
||||||
Ok(keys)
|
Ok(keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn queued_requests<'a>(
|
fn queued_requests<'a>(
|
||||||
&'a self,
|
&'a self, outgoing_kind: &OutgoingKind,
|
||||||
outgoing_kind: &OutgoingKind,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||||
let prefix = outgoing_kind.get_prefix();
|
let prefix = outgoing_kind.get_prefix();
|
||||||
return Box::new(
|
return Box::new(
|
||||||
|
@ -111,37 +103,27 @@ impl service::sending::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||||
self.servername_educount
|
self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||||
self.servername_educount
|
self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| {
|
||||||
.get(server_name.as_bytes())?
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||||
.map_or(Ok(0), |bytes| {
|
|
||||||
utils::u64_from_bytes(&bytes)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(key))]
|
#[tracing::instrument(skip(key))]
|
||||||
fn parse_servercurrentevent(
|
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind, SendingEventType)> {
|
||||||
key: &[u8],
|
|
||||||
value: Vec<u8>,
|
|
||||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
|
||||||
// Appservices start with a plus
|
// Appservices start with a plus
|
||||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
|
||||||
|
|
||||||
let server = parts.next().expect("splitn always returns one element");
|
let server = parts.next().expect("splitn always returns one element");
|
||||||
let event = parts
|
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||||
.next()
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
|
||||||
|
|
||||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
let server = utils::string_from_bytes(server)
|
||||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
(
|
(
|
||||||
OutgoingKind::Appservice(server),
|
OutgoingKind::Appservice(server),
|
||||||
|
@ -152,23 +134,19 @@ fn parse_servercurrentevent(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else if key.starts_with(b"$") {
|
} else if key.starts_with(b"$") {
|
||||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
|
||||||
|
|
||||||
let user = parts.next().expect("splitn always returns one element");
|
let user = parts.next().expect("splitn always returns one element");
|
||||||
let user_string = utils::string_from_bytes(user)
|
let user_string = utils::string_from_bytes(user)
|
||||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||||
let user_id = UserId::parse(user_string)
|
let user_id =
|
||||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||||
|
|
||||||
let pushkey = parts
|
let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||||
.next()
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
|
||||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||||
|
|
||||||
let event = parts
|
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||||
.next()
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
|
||||||
|
|
||||||
(
|
(
|
||||||
OutgoingKind::Push(user_id, pushkey_string),
|
OutgoingKind::Push(user_id, pushkey_string),
|
||||||
|
@ -180,21 +158,19 @@ fn parse_servercurrentevent(
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||||
|
|
||||||
let server = parts.next().expect("splitn always returns one element");
|
let server = parts.next().expect("splitn always returns one element");
|
||||||
let event = parts
|
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||||
.next()
|
|
||||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
|
||||||
|
|
||||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
let server = utils::string_from_bytes(server)
|
||||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||||
})?;
|
|
||||||
|
|
||||||
(
|
(
|
||||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
OutgoingKind::Normal(
|
||||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
ServerName::parse(server)
|
||||||
})?),
|
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
|
||||||
|
),
|
||||||
if value.is_empty() {
|
if value.is_empty() {
|
||||||
SendingEventType::Pdu(event.to_vec())
|
SendingEventType::Pdu(event.to_vec())
|
||||||
} else {
|
} else {
|
||||||
|
|
|
@ -4,16 +4,12 @@ use crate::{database::KeyValueDatabase, service, Result};
|
||||||
|
|
||||||
impl service::transaction_ids::Data for KeyValueDatabase {
|
impl service::transaction_ids::Data for KeyValueDatabase {
|
||||||
fn add_txnid(
|
fn add_txnid(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
|
||||||
user_id: &UserId,
|
|
||||||
device_id: Option<&DeviceId>,
|
|
||||||
txn_id: &TransactionId,
|
|
||||||
data: &[u8],
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(txn_id.as_bytes());
|
key.extend_from_slice(txn_id.as_bytes());
|
||||||
|
|
||||||
self.userdevicetxnid_response.insert(&key, data)?;
|
self.userdevicetxnid_response.insert(&key, data)?;
|
||||||
|
@ -22,15 +18,12 @@ impl service::transaction_ids::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn existing_txnid(
|
fn existing_txnid(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: Option<&DeviceId>,
|
|
||||||
txn_id: &TransactionId,
|
|
||||||
) -> Result<Option<Vec<u8>>> {
|
) -> Result<Option<Vec<u8>>> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(txn_id.as_bytes());
|
key.extend_from_slice(txn_id.as_bytes());
|
||||||
|
|
||||||
// If there's no entry, this is a new transaction
|
// If there's no entry, this is a new transaction
|
||||||
|
|
|
@ -7,16 +7,9 @@ use crate::{database::KeyValueDatabase, service, Error, Result};
|
||||||
|
|
||||||
impl service::uiaa::Data for KeyValueDatabase {
|
impl service::uiaa::Data for KeyValueDatabase {
|
||||||
fn set_uiaa_request(
|
fn set_uiaa_request(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
session: &str,
|
|
||||||
request: &CanonicalJsonValue,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.userdevicesessionid_uiaarequest
|
self.userdevicesessionid_uiaarequest.write().unwrap().insert(
|
||||||
.write()
|
|
||||||
.unwrap()
|
|
||||||
.insert(
|
|
||||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||||
request.to_owned(),
|
request.to_owned(),
|
||||||
);
|
);
|
||||||
|
@ -24,12 +17,7 @@ impl service::uiaa::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_uiaa_request(
|
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
session: &str,
|
|
||||||
) -> Option<CanonicalJsonValue> {
|
|
||||||
self.userdevicesessionid_uiaarequest
|
self.userdevicesessionid_uiaarequest
|
||||||
.read()
|
.read()
|
||||||
.unwrap()
|
.unwrap()
|
||||||
|
@ -38,16 +26,12 @@ impl service::uiaa::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_uiaa_session(
|
fn update_uiaa_session(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
session: &str,
|
|
||||||
uiaainfo: Option<&UiaaInfo>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||||
userdevicesessionid.push(0xff);
|
userdevicesessionid.push(0xFF);
|
||||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||||
userdevicesessionid.push(0xff);
|
userdevicesessionid.push(0xFF);
|
||||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||||
|
|
||||||
if let Some(uiaainfo) = uiaainfo {
|
if let Some(uiaainfo) = uiaainfo {
|
||||||
|
@ -56,33 +40,24 @@ impl service::uiaa::Data for KeyValueDatabase {
|
||||||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||||
)?;
|
)?;
|
||||||
} else {
|
} else {
|
||||||
self.userdevicesessionid_uiaainfo
|
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
|
||||||
.remove(&userdevicesessionid)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_uiaa_session(
|
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
session: &str,
|
|
||||||
) -> Result<UiaaInfo> {
|
|
||||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||||
userdevicesessionid.push(0xff);
|
userdevicesessionid.push(0xFF);
|
||||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||||
userdevicesessionid.push(0xff);
|
userdevicesessionid.push(0xFF);
|
||||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||||
|
|
||||||
serde_json::from_slice(
|
serde_json::from_slice(
|
||||||
&self
|
&self
|
||||||
.userdevicesessionid_uiaainfo
|
.userdevicesessionid_uiaainfo
|
||||||
.get(&userdevicesessionid)?
|
.get(&userdevicesessionid)?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?,
|
||||||
ErrorKind::Forbidden,
|
|
||||||
"UIAA session does not exist.",
|
|
||||||
))?,
|
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,8 @@ use ruma::{
|
||||||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
|
||||||
events::{AnyToDeviceEvent, StateEventType},
|
events::{AnyToDeviceEvent, StateEventType},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
|
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
|
||||||
OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
|
OwnedMxcUri, OwnedUserId, UInt, UserId,
|
||||||
};
|
};
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
|
@ -18,50 +18,37 @@ use crate::{
|
||||||
|
|
||||||
impl service::users::Data for KeyValueDatabase {
|
impl service::users::Data for KeyValueDatabase {
|
||||||
/// Check if a user has an account on this homeserver.
|
/// Check if a user has an account on this homeserver.
|
||||||
fn exists(&self, user_id: &UserId) -> Result<bool> {
|
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
|
||||||
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Check if account is deactivated
|
/// Check if account is deactivated
|
||||||
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||||
Ok(self
|
Ok(self
|
||||||
.userid_password
|
.userid_password
|
||||||
.get(user_id.as_bytes())?
|
.get(user_id.as_bytes())?
|
||||||
.ok_or(Error::BadRequest(
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"User does not exist.",
|
|
||||||
))?
|
|
||||||
.is_empty())
|
.is_empty())
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns the number of users registered on this server.
|
/// Returns the number of users registered on this server.
|
||||||
fn count(&self) -> Result<usize> {
|
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
|
||||||
Ok(self.userid_password.iter().count())
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Find out which user an access token belongs to.
|
/// Find out which user an access token belongs to.
|
||||||
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
|
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
|
||||||
self.token_userdeviceid
|
self.token_userdeviceid.get(token.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||||
.get(token.as_bytes())?
|
let mut parts = bytes.split(|&b| b == 0xFF);
|
||||||
.map_or(Ok(None), |bytes| {
|
let user_bytes =
|
||||||
let mut parts = bytes.split(|&b| b == 0xff);
|
parts.next().ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
|
||||||
let user_bytes = parts.next().ok_or_else(|| {
|
let device_bytes =
|
||||||
Error::bad_database("User ID in token_userdeviceid is invalid.")
|
parts.next().ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
|
||||||
})?;
|
|
||||||
let device_bytes = parts.next().ok_or_else(|| {
|
|
||||||
Error::bad_database("Device ID in token_userdeviceid is invalid.")
|
|
||||||
})?;
|
|
||||||
|
|
||||||
Ok(Some((
|
Ok(Some((
|
||||||
UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| {
|
UserId::parse(
|
||||||
Error::bad_database("User ID in token_userdeviceid is invalid unicode.")
|
utils::string_from_bytes(user_bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
|
||||||
.map_err(|_| {
|
)
|
||||||
Error::bad_database("User ID in token_userdeviceid is invalid.")
|
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
|
||||||
})?,
|
utils::string_from_bytes(device_bytes)
|
||||||
utils::string_from_bytes(device_bytes).map_err(|_| {
|
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
|
||||||
Error::bad_database("Device ID in token_userdeviceid is invalid.")
|
|
||||||
})?,
|
|
||||||
)))
|
)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
@ -69,16 +56,18 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
/// Returns an iterator over all users on this homeserver.
|
/// Returns an iterator over all users on this homeserver.
|
||||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||||
Box::new(self.userid_password.iter().map(|(bytes, _)| {
|
Box::new(self.userid_password.iter().map(|(bytes, _)| {
|
||||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
UserId::parse(
|
||||||
Error::bad_database("User ID in userid_password is invalid unicode.")
|
utils::string_from_bytes(&bytes)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
|
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
|
||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns a list of local users as list of usernames.
|
/// Returns a list of local users as list of usernames.
|
||||||
///
|
///
|
||||||
/// A user account is considered `local` if the length of it's password is greater then zero.
|
/// A user account is considered `local` if the length of it's password is
|
||||||
|
/// greater then zero.
|
||||||
fn list_local_users(&self) -> Result<Vec<String>> {
|
fn list_local_users(&self) -> Result<Vec<String>> {
|
||||||
let users: Vec<String> = self
|
let users: Vec<String> = self
|
||||||
.userid_password
|
.userid_password
|
||||||
|
@ -90,9 +79,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
/// Returns the password hash for the given user.
|
/// Returns the password hash for the given user.
|
||||||
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||||
self.userid_password
|
self.userid_password.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||||
.get(user_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |bytes| {
|
|
||||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||||
Error::bad_database("Password hash in db is not valid string.")
|
Error::bad_database("Password hash in db is not valid string.")
|
||||||
})?))
|
})?))
|
||||||
|
@ -103,8 +90,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||||
if let Some(password) = password {
|
if let Some(password) = password {
|
||||||
if let Ok(hash) = utils::calculate_password_hash(password) {
|
if let Ok(hash) = utils::calculate_password_hash(password) {
|
||||||
self.userid_password
|
self.userid_password.insert(user_id.as_bytes(), hash.as_bytes())?;
|
||||||
.insert(user_id.as_bytes(), hash.as_bytes())?;
|
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
Err(Error::BadRequest(
|
Err(Error::BadRequest(
|
||||||
|
@ -120,20 +106,18 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
/// Returns the displayname of a user on this homeserver.
|
/// Returns the displayname of a user on this homeserver.
|
||||||
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||||
self.userid_displayname
|
self.userid_displayname.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||||
.get(user_id.as_bytes())?
|
Ok(Some(
|
||||||
.map_or(Ok(None), |bytes| {
|
utils::string_from_bytes(&bytes).map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
|
||||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
))
|
||||||
Error::bad_database("Displayname in db is invalid.")
|
|
||||||
})?))
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets a new displayname or removes it if displayname is None. You still need to nofify all rooms of this change.
|
/// Sets a new displayname or removes it if displayname is None. You still
|
||||||
|
/// need to nofify all rooms of this change.
|
||||||
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
||||||
if let Some(displayname) = displayname {
|
if let Some(displayname) = displayname {
|
||||||
self.userid_displayname
|
self.userid_displayname.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
||||||
.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
|
||||||
} else {
|
} else {
|
||||||
self.userid_displayname.remove(user_id.as_bytes())?;
|
self.userid_displayname.remove(user_id.as_bytes())?;
|
||||||
}
|
}
|
||||||
|
@ -159,8 +143,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||||
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
|
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
|
||||||
if let Some(avatar_url) = avatar_url {
|
if let Some(avatar_url) = avatar_url {
|
||||||
self.userid_avatarurl
|
self.userid_avatarurl.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
||||||
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
|
||||||
} else {
|
} else {
|
||||||
self.userid_avatarurl.remove(user_id.as_bytes())?;
|
self.userid_avatarurl.remove(user_id.as_bytes())?;
|
||||||
}
|
}
|
||||||
|
@ -184,8 +167,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||||
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
|
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
|
||||||
if let Some(blurhash) = blurhash {
|
if let Some(blurhash) = blurhash {
|
||||||
self.userid_blurhash
|
self.userid_blurhash.insert(user_id.as_bytes(), blurhash.as_bytes())?;
|
||||||
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
|
|
||||||
} else {
|
} else {
|
||||||
self.userid_blurhash.remove(user_id.as_bytes())?;
|
self.userid_blurhash.remove(user_id.as_bytes())?;
|
||||||
}
|
}
|
||||||
|
@ -195,30 +177,20 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
/// Adds a new device to a user.
|
/// Adds a new device to a user.
|
||||||
fn create_device(
|
fn create_device(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
token: &str,
|
|
||||||
initial_device_display_name: Option<String>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// This method should never be called for nonexistent users. We shouldn't assert though...
|
// This method should never be called for nonexistent users. We shouldn't assert
|
||||||
|
// though...
|
||||||
if !self.exists(user_id)? {
|
if !self.exists(user_id)? {
|
||||||
warn!(
|
warn!("Called create_device for non-existent user {} in database", user_id);
|
||||||
"Called create_device for non-existent user {} in database",
|
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
|
||||||
user_id
|
|
||||||
);
|
|
||||||
return Err(Error::BadRequest(
|
|
||||||
ErrorKind::InvalidParam,
|
|
||||||
"User does not exist.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
self.userid_devicelistversion
|
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||||
.increment(user_id.as_bytes())?;
|
|
||||||
|
|
||||||
self.userdeviceid_metadata.insert(
|
self.userdeviceid_metadata.insert(
|
||||||
&userdeviceid,
|
&userdeviceid,
|
||||||
|
@ -239,7 +211,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
/// Removes a device from a user.
|
/// Removes a device from a user.
|
||||||
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
// Remove tokens
|
// Remove tokens
|
||||||
|
@ -250,7 +222,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
// Remove todevice events
|
// Remove todevice events
|
||||||
let mut prefix = userdeviceid.clone();
|
let mut prefix = userdeviceid.clone();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
|
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
|
||||||
self.todeviceid_events.remove(&key)?;
|
self.todeviceid_events.remove(&key)?;
|
||||||
|
@ -258,8 +230,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
// TODO: Remove onetimekeys
|
// TODO: Remove onetimekeys
|
||||||
|
|
||||||
self.userid_devicelistversion
|
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||||
.increment(user_id.as_bytes())?;
|
|
||||||
|
|
||||||
self.userdeviceid_metadata.remove(&userdeviceid)?;
|
self.userdeviceid_metadata.remove(&userdeviceid)?;
|
||||||
|
|
||||||
|
@ -267,39 +238,34 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns an iterator over all device ids of this user.
|
/// Returns an iterator over all device ids of this user.
|
||||||
fn all_device_ids<'a>(
|
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
|
||||||
&'a self,
|
|
||||||
user_id: &UserId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
// All devices have metadata
|
// All devices have metadata
|
||||||
Box::new(
|
Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(|(bytes, _)| {
|
||||||
self.userdeviceid_metadata
|
|
||||||
.scan_prefix(prefix)
|
|
||||||
.map(|(bytes, _)| {
|
|
||||||
Ok(utils::string_from_bytes(
|
Ok(utils::string_from_bytes(
|
||||||
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
bytes
|
||||||
Error::bad_database("UserDevice ID in db is invalid.")
|
.rsplit(|&b| b == 0xFF)
|
||||||
})?,
|
.next()
|
||||||
|
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
|
||||||
)
|
)
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
|
||||||
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
|
|
||||||
})?
|
|
||||||
.into())
|
.into())
|
||||||
}),
|
}))
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Replaces the access token of one device.
|
/// Replaces the access token of one device.
|
||||||
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
// should not be None, but we shouldn't assert either lol...
|
// should not be None, but we shouldn't assert either lol...
|
||||||
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
||||||
warn!("Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
warn!(
|
||||||
|
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
|
||||||
|
user_id, device_id
|
||||||
|
);
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database(
|
||||||
"User does not exist or device ID has no metadata in database.",
|
"User does not exist or device ID has no metadata in database.",
|
||||||
));
|
));
|
||||||
|
@ -312,41 +278,39 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
// Assign token to user device combination
|
// Assign token to user device combination
|
||||||
self.userdeviceid_token
|
self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?;
|
||||||
.insert(&userdeviceid, token.as_bytes())?;
|
self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?;
|
||||||
self.token_userdeviceid
|
|
||||||
.insert(token.as_bytes(), &userdeviceid)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_one_time_key(
|
fn add_one_time_key(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
one_time_key_key: &DeviceKeyId,
|
|
||||||
one_time_key_value: &Raw<OneTimeKey>,
|
one_time_key_value: &Raw<OneTimeKey>,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(device_id.as_bytes());
|
key.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
// All devices have metadata
|
// All devices have metadata
|
||||||
// Only existing devices should be able to call this, but we shouldn't assert either...
|
// Only existing devices should be able to call this, but we shouldn't assert
|
||||||
|
// either...
|
||||||
if self.userdeviceid_metadata.get(&key)?.is_none() {
|
if self.userdeviceid_metadata.get(&key)?.is_none() {
|
||||||
warn!("Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
warn!(
|
||||||
|
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
|
||||||
|
database",
|
||||||
|
user_id, device_id
|
||||||
|
);
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database(
|
||||||
"User does not exist or device ID has no metadata in database.",
|
"User does not exist or device ID has no metadata in database.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
|
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
|
||||||
// because there are no wrapping quotation marks anymore)
|
// because there are no wrapping quotation marks anymore)
|
||||||
key.extend_from_slice(
|
key.extend_from_slice(
|
||||||
serde_json::to_string(one_time_key_key)
|
serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works").as_bytes(),
|
||||||
.expect("DeviceKeyId::to_string always works")
|
|
||||||
.as_bytes(),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
self.onetimekeyid_onetimekeys.insert(
|
self.onetimekeyid_onetimekeys.insert(
|
||||||
|
@ -354,10 +318,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
self.userid_lastonetimekeyupdate.insert(
|
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||||
user_id.as_bytes(),
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -366,31 +327,24 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
self.userid_lastonetimekeyupdate
|
self.userid_lastonetimekeyupdate
|
||||||
.get(user_id.as_bytes())?
|
.get(user_id.as_bytes())?
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
utils::u64_from_bytes(&bytes)
|
||||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||||
})
|
|
||||||
})
|
})
|
||||||
.unwrap_or(Ok(0))
|
.unwrap_or(Ok(0))
|
||||||
}
|
}
|
||||||
|
|
||||||
fn take_one_time_key(
|
fn take_one_time_key(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
key_algorithm: &DeviceKeyAlgorithm,
|
|
||||||
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
|
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(device_id.as_bytes());
|
prefix.extend_from_slice(device_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.push(b'"'); // Annoying quotation mark
|
prefix.push(b'"'); // Annoying quotation mark
|
||||||
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
||||||
prefix.push(b':');
|
prefix.push(b':');
|
||||||
|
|
||||||
self.userid_lastonetimekeyupdate.insert(
|
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||||
user_id.as_bytes(),
|
|
||||||
&services().globals.next_count()?.to_be_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
self.onetimekeyid_onetimekeys
|
self.onetimekeyid_onetimekeys
|
||||||
.scan_prefix(prefix)
|
.scan_prefix(prefix)
|
||||||
|
@ -400,7 +354,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
|
|
||||||
Ok((
|
Ok((
|
||||||
serde_json::from_slice(
|
serde_json::from_slice(
|
||||||
key.rsplit(|&b| b == 0xff)
|
key.rsplit(|&b| b == 0xFF)
|
||||||
.next()
|
.next()
|
||||||
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
|
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
|
||||||
)
|
)
|
||||||
|
@ -413,45 +367,35 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn count_one_time_keys(
|
fn count_one_time_keys(
|
||||||
&self,
|
&self, user_id: &UserId, device_id: &DeviceId,
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
|
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
let mut counts = BTreeMap::new();
|
let mut counts = BTreeMap::new();
|
||||||
|
|
||||||
for algorithm in
|
for algorithm in self.onetimekeyid_onetimekeys.scan_prefix(userdeviceid).map(|(bytes, _)| {
|
||||||
self.onetimekeyid_onetimekeys
|
|
||||||
.scan_prefix(userdeviceid)
|
|
||||||
.map(|(bytes, _)| {
|
|
||||||
Ok::<_, Error>(
|
Ok::<_, Error>(
|
||||||
serde_json::from_slice::<OwnedDeviceKeyId>(
|
serde_json::from_slice::<OwnedDeviceKeyId>(
|
||||||
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
bytes
|
||||||
Error::bad_database("OneTimeKey ID in db is invalid.")
|
.rsplit(|&b| b == 0xFF)
|
||||||
})?,
|
.next()
|
||||||
|
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
|
||||||
)
|
)
|
||||||
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
|
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
|
||||||
.algorithm(),
|
.algorithm(),
|
||||||
)
|
)
|
||||||
})
|
}) {
|
||||||
{
|
|
||||||
*counts.entry(algorithm?).or_default() += UInt::from(1_u32);
|
*counts.entry(algorithm?).or_default() += UInt::from(1_u32);
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(counts)
|
Ok(counts)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_device_keys(
|
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
device_keys: &Raw<DeviceKeys>,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
self.keyid_key.insert(
|
self.keyid_key.insert(
|
||||||
|
@ -465,39 +409,30 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_cross_signing_keys(
|
fn add_cross_signing_keys(
|
||||||
&self,
|
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||||
user_id: &UserId,
|
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
|
||||||
master_key: &Raw<CrossSigningKey>,
|
|
||||||
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
|
||||||
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
|
||||||
notify: bool,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// TODO: Check signatures
|
// TODO: Check signatures
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
|
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
|
||||||
|
|
||||||
self.keyid_key
|
self.keyid_key.insert(&master_key_key, master_key.json().get().as_bytes())?;
|
||||||
.insert(&master_key_key, master_key.json().get().as_bytes())?;
|
|
||||||
|
|
||||||
self.userid_masterkeyid
|
self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?;
|
||||||
.insert(user_id.as_bytes(), &master_key_key)?;
|
|
||||||
|
|
||||||
// Self-signing key
|
// Self-signing key
|
||||||
if let Some(self_signing_key) = self_signing_key {
|
if let Some(self_signing_key) = self_signing_key {
|
||||||
let mut self_signing_key_ids = self_signing_key
|
let mut self_signing_key_ids = self_signing_key
|
||||||
.deserialize()
|
.deserialize()
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key")
|
|
||||||
})?
|
|
||||||
.keys
|
.keys
|
||||||
.into_values();
|
.into_values();
|
||||||
|
|
||||||
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
|
let self_signing_key_id = self_signing_key_ids
|
||||||
ErrorKind::InvalidParam,
|
.next()
|
||||||
"Self signing key contained no key.",
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
|
||||||
))?;
|
|
||||||
|
|
||||||
if self_signing_key_ids.next().is_some() {
|
if self_signing_key_ids.next().is_some() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
|
@ -509,29 +444,22 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
let mut self_signing_key_key = prefix.clone();
|
let mut self_signing_key_key = prefix.clone();
|
||||||
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
|
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
|
||||||
|
|
||||||
self.keyid_key.insert(
|
self.keyid_key.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
|
||||||
&self_signing_key_key,
|
|
||||||
self_signing_key.json().get().as_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
self.userid_selfsigningkeyid
|
self.userid_selfsigningkeyid.insert(user_id.as_bytes(), &self_signing_key_key)?;
|
||||||
.insert(user_id.as_bytes(), &self_signing_key_key)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// User-signing key
|
// User-signing key
|
||||||
if let Some(user_signing_key) = user_signing_key {
|
if let Some(user_signing_key) = user_signing_key {
|
||||||
let mut user_signing_key_ids = user_signing_key
|
let mut user_signing_key_ids = user_signing_key
|
||||||
.deserialize()
|
.deserialize()
|
||||||
.map_err(|_| {
|
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
|
||||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key")
|
|
||||||
})?
|
|
||||||
.keys
|
.keys
|
||||||
.into_values();
|
.into_values();
|
||||||
|
|
||||||
let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest(
|
let user_signing_key_id = user_signing_key_ids
|
||||||
ErrorKind::InvalidParam,
|
.next()
|
||||||
"User signing key contained no key.",
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
|
||||||
))?;
|
|
||||||
|
|
||||||
if user_signing_key_ids.next().is_some() {
|
if user_signing_key_ids.next().is_some() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
|
@ -543,13 +471,9 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
let mut user_signing_key_key = prefix;
|
let mut user_signing_key_key = prefix;
|
||||||
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
|
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
|
||||||
|
|
||||||
self.keyid_key.insert(
|
self.keyid_key.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
|
||||||
&user_signing_key_key,
|
|
||||||
user_signing_key.json().get().as_bytes(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
self.userid_usersigningkeyid
|
self.userid_usersigningkeyid.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
||||||
.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if notify {
|
if notify {
|
||||||
|
@ -560,21 +484,18 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn sign_key(
|
fn sign_key(
|
||||||
&self,
|
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
|
||||||
target_id: &UserId,
|
|
||||||
key_id: &str,
|
|
||||||
signature: (String, String),
|
|
||||||
sender_id: &UserId,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut key = target_id.as_bytes().to_vec();
|
let mut key = target_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(key_id.as_bytes());
|
key.extend_from_slice(key_id.as_bytes());
|
||||||
|
|
||||||
let mut cross_signing_key: serde_json::Value =
|
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
|
||||||
serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest(
|
&self
|
||||||
ErrorKind::InvalidParam,
|
.keyid_key
|
||||||
"Tried to sign nonexistent key.",
|
.get(&key)?
|
||||||
))?)
|
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
|
||||||
|
)
|
||||||
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
|
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
|
||||||
|
|
||||||
let signatures = cross_signing_key
|
let signatures = cross_signing_key
|
||||||
|
@ -601,13 +522,10 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn keys_changed<'a>(
|
fn keys_changed<'a>(
|
||||||
&'a self,
|
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
|
||||||
user_or_room_id: &str,
|
|
||||||
from: u64,
|
|
||||||
to: Option<u64>,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||||
let mut prefix = user_or_room_id.as_bytes().to_vec();
|
let mut prefix = user_or_room_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut start = prefix.clone();
|
let mut start = prefix.clone();
|
||||||
start.extend_from_slice(&(from + 1).to_be_bytes());
|
start.extend_from_slice(&(from + 1).to_be_bytes());
|
||||||
|
@ -619,7 +537,7 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
.iter_from(&start, false)
|
.iter_from(&start, false)
|
||||||
.take_while(move |(k, _)| {
|
.take_while(move |(k, _)| {
|
||||||
k.starts_with(&prefix)
|
k.starts_with(&prefix)
|
||||||
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) {
|
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
|
||||||
if let Ok(c) = utils::u64_from_bytes(current) {
|
if let Ok(c) = utils::u64_from_bytes(current) {
|
||||||
c <= to
|
c <= to
|
||||||
} else {
|
} else {
|
||||||
|
@ -632,83 +550,63 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
.map(|(_, bytes)| {
|
.map(|(_, bytes)| {
|
||||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
UserId::parse(
|
||||||
Error::bad_database(
|
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||||
"User ID in devicekeychangeid_userid is invalid unicode.",
|
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
|
||||||
|
})?,
|
||||||
)
|
)
|
||||||
})?)
|
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
|
||||||
.map_err(|_| {
|
|
||||||
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
|
|
||||||
})
|
|
||||||
}),
|
}),
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
|
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
|
||||||
let count = services().globals.next_count()?.to_be_bytes();
|
let count = services().globals.next_count()?.to_be_bytes();
|
||||||
for room_id in services()
|
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.rooms_joined(user_id)
|
|
||||||
.filter_map(std::result::Result::ok)
|
|
||||||
{
|
|
||||||
// Don't send key updates to unencrypted rooms
|
// Don't send key updates to unencrypted rooms
|
||||||
if services()
|
if services().rooms.state_accessor.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?.is_none()
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
|
|
||||||
.is_none()
|
|
||||||
{
|
{
|
||||||
continue;
|
continue;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut key = room_id.as_bytes().to_vec();
|
let mut key = room_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(&count);
|
key.extend_from_slice(&count);
|
||||||
|
|
||||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(&count);
|
key.extend_from_slice(&count);
|
||||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_device_keys(
|
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
) -> Result<Option<Raw<DeviceKeys>>> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(device_id.as_bytes());
|
key.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
Ok(Some(
|
||||||
Error::bad_database("DeviceKeys in db are invalid.")
|
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
|
||||||
})?))
|
))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn parse_master_key(
|
fn parse_master_key(
|
||||||
&self,
|
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
|
||||||
user_id: &UserId,
|
|
||||||
master_key: &Raw<CrossSigningKey>,
|
|
||||||
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let master_key = master_key
|
let master_key =
|
||||||
.deserialize()
|
master_key.deserialize().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
||||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
|
||||||
let mut master_key_ids = master_key.keys.values();
|
let mut master_key_ids = master_key.keys.values();
|
||||||
let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
|
let master_key_id =
|
||||||
ErrorKind::InvalidParam,
|
master_key_ids.next().ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
|
||||||
"Master key contained no key.",
|
|
||||||
))?;
|
|
||||||
if master_key_ids.next().is_some() {
|
if master_key_ids.next().is_some() {
|
||||||
return Err(Error::BadRequest(
|
return Err(Error::BadRequest(
|
||||||
ErrorKind::InvalidParam,
|
ErrorKind::InvalidParam,
|
||||||
|
@ -721,79 +619,54 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_key(
|
fn get_key(
|
||||||
&self,
|
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||||
key: &[u8],
|
|
||||||
sender_user: Option<&UserId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
|
||||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||||
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
|
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
|
||||||
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
|
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
|
||||||
clean_signatures(
|
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
|
||||||
&mut cross_signing_key,
|
|
||||||
sender_user,
|
|
||||||
user_id,
|
|
||||||
allowed_signatures,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(Some(Raw::from_json(
|
Ok(Some(Raw::from_json(
|
||||||
serde_json::value::to_raw_value(&cross_signing_key)
|
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
|
||||||
.expect("Value to RawValue serialization"),
|
|
||||||
)))
|
)))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_master_key(
|
fn get_master_key(
|
||||||
&self,
|
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||||
sender_user: Option<&UserId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
|
||||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||||
self.userid_masterkeyid
|
self.userid_masterkeyid
|
||||||
.get(user_id.as_bytes())?
|
.get(user_id.as_bytes())?
|
||||||
.map_or(Ok(None), |key| {
|
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
|
||||||
self.get_key(&key, sender_user, user_id, allowed_signatures)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_self_signing_key(
|
fn get_self_signing_key(
|
||||||
&self,
|
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||||
sender_user: Option<&UserId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
|
||||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||||
self.userid_selfsigningkeyid
|
self.userid_selfsigningkeyid
|
||||||
.get(user_id.as_bytes())?
|
.get(user_id.as_bytes())?
|
||||||
.map_or(Ok(None), |key| {
|
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
|
||||||
self.get_key(&key, sender_user, user_id, allowed_signatures)
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
|
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||||
self.userid_usersigningkeyid
|
self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(Ok(None), |key| {
|
||||||
.get(user_id.as_bytes())?
|
|
||||||
.map_or(Ok(None), |key| {
|
|
||||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
Ok(Some(
|
||||||
Error::bad_database("CrossSigningKey in db is invalid.")
|
serde_json::from_slice(&bytes)
|
||||||
})?))
|
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
|
||||||
|
))
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn add_to_device_event(
|
fn add_to_device_event(
|
||||||
&self,
|
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
|
||||||
sender: &UserId,
|
|
||||||
target_user_id: &UserId,
|
|
||||||
target_device_id: &DeviceId,
|
|
||||||
event_type: &str,
|
|
||||||
content: serde_json::Value,
|
content: serde_json::Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut key = target_user_id.as_bytes().to_vec();
|
let mut key = target_user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(target_device_id.as_bytes());
|
key.extend_from_slice(target_device_id.as_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||||
|
|
||||||
let mut json = serde_json::Map::new();
|
let mut json = serde_json::Map::new();
|
||||||
|
@ -808,17 +681,13 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_to_device_events(
|
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
|
|
||||||
let mut events = Vec::new();
|
let mut events = Vec::new();
|
||||||
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(device_id.as_bytes());
|
prefix.extend_from_slice(device_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
|
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
|
||||||
events.push(
|
events.push(
|
||||||
|
@ -830,16 +699,11 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
Ok(events)
|
Ok(events)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_to_device_events(
|
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
until: u64,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut prefix = user_id.as_bytes().to_vec();
|
let mut prefix = user_id.as_bytes().to_vec();
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
prefix.extend_from_slice(device_id.as_bytes());
|
prefix.extend_from_slice(device_id.as_bytes());
|
||||||
prefix.push(0xff);
|
prefix.push(0xFF);
|
||||||
|
|
||||||
let mut last = prefix.clone();
|
let mut last = prefix.clone();
|
||||||
last.extend_from_slice(&until.to_be_bytes());
|
last.extend_from_slice(&until.to_be_bytes());
|
||||||
|
@ -864,26 +728,25 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
|
||||||
fn update_device_metadata(
|
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
device: &Device,
|
|
||||||
) -> Result<()> {
|
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
// Only existing devices should be able to call this, but we shouldn't assert either...
|
// Only existing devices should be able to call this, but we shouldn't assert
|
||||||
|
// either...
|
||||||
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
||||||
warn!("Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
warn!(
|
||||||
|
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
|
||||||
|
metadata in database",
|
||||||
|
user_id, device_id
|
||||||
|
);
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database(
|
||||||
"User does not exist or device ID has no metadata in database.",
|
"User does not exist or device ID has no metadata in database.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
|
|
||||||
self.userid_devicelistversion
|
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||||
.increment(user_id.as_bytes())?;
|
|
||||||
|
|
||||||
self.userdeviceid_metadata.insert(
|
self.userdeviceid_metadata.insert(
|
||||||
&userdeviceid,
|
&userdeviceid,
|
||||||
|
@ -894,18 +757,12 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Get device metadata.
|
/// Get device metadata.
|
||||||
fn get_device_metadata(
|
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
device_id: &DeviceId,
|
|
||||||
) -> Result<Option<Device>> {
|
|
||||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||||
userdeviceid.push(0xff);
|
userdeviceid.push(0xFF);
|
||||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||||
|
|
||||||
self.userdeviceid_metadata
|
self.userdeviceid_metadata.get(&userdeviceid)?.map_or(Ok(None), |bytes| {
|
||||||
.get(&userdeviceid)?
|
|
||||||
.map_or(Ok(None), |bytes| {
|
|
||||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||||
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
|
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
|
||||||
})?))
|
})?))
|
||||||
|
@ -913,31 +770,19 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
|
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||||
self.userid_devicelistversion
|
self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||||
.get(user_id.as_bytes())?
|
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid devicelistversion in db.")).map(Some)
|
||||||
.map_or(Ok(None), |bytes| {
|
|
||||||
utils::u64_from_bytes(&bytes)
|
|
||||||
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
|
|
||||||
.map(Some)
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn all_devices_metadata<'a>(
|
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
|
||||||
&'a self,
|
|
||||||
user_id: &UserId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
|
|
||||||
Box::new(
|
Box::new(self.userdeviceid_metadata.scan_prefix(key).map(|(_, bytes)| {
|
||||||
self.userdeviceid_metadata
|
serde_json::from_slice::<Device>(&bytes)
|
||||||
.scan_prefix(key)
|
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
|
||||||
.map(|(_, bytes)| {
|
}))
|
||||||
serde_json::from_slice::<Device>(&bytes).map_err(|_| {
|
|
||||||
Error::bad_database("Device in userdeviceid_metadata is invalid.")
|
|
||||||
})
|
|
||||||
}),
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a new sync filter. Returns the filter id.
|
/// Creates a new sync filter. Returns the filter id.
|
||||||
|
@ -945,27 +790,23 @@ impl service::users::Data for KeyValueDatabase {
|
||||||
let filter_id = utils::random_string(4);
|
let filter_id = utils::random_string(4);
|
||||||
|
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(filter_id.as_bytes());
|
key.extend_from_slice(filter_id.as_bytes());
|
||||||
|
|
||||||
self.userfilterid_filter.insert(
|
self.userfilterid_filter.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
|
||||||
&key,
|
|
||||||
&serde_json::to_vec(&filter).expect("filter is valid json"),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
Ok(filter_id)
|
Ok(filter_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
|
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
|
||||||
let mut key = user_id.as_bytes().to_vec();
|
let mut key = user_id.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(filter_id.as_bytes());
|
key.extend_from_slice(filter_id.as_bytes());
|
||||||
|
|
||||||
let raw = self.userfilterid_filter.get(&key)?;
|
let raw = self.userfilterid_filter.get(&key)?;
|
||||||
|
|
||||||
if let Some(raw) = raw {
|
if let Some(raw) = raw {
|
||||||
serde_json::from_slice(&raw)
|
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
|
||||||
.map_err(|_| Error::bad_database("Invalid filter event in db."))
|
|
||||||
} else {
|
} else {
|
||||||
Ok(None)
|
Ok(None)
|
||||||
}
|
}
|
||||||
|
@ -976,8 +817,8 @@ impl KeyValueDatabase {}
|
||||||
|
|
||||||
/// Will only return with Some(username) if the password was not empty and the
|
/// Will only return with Some(username) if the password was not empty and the
|
||||||
/// username could be successfully parsed.
|
/// username could be successfully parsed.
|
||||||
/// If utils::string_from_bytes(...) returns an error that username will be skipped
|
/// If utils::string_from_bytes(...) returns an error that username will be
|
||||||
/// and the error will be logged.
|
/// skipped and the error will be logged.
|
||||||
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
|
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
|
||||||
// A valid password is not empty
|
// A valid password is not empty
|
||||||
if password.is_empty() {
|
if password.is_empty() {
|
||||||
|
@ -986,12 +827,9 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<
|
||||||
match utils::string_from_bytes(username) {
|
match utils::string_from_bytes(username) {
|
||||||
Ok(u) => Some(u),
|
Ok(u) => Some(u),
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!(
|
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
|
||||||
"Failed to parse username while calling get_local_users(): {}",
|
|
||||||
e.to_string()
|
|
||||||
);
|
|
||||||
None
|
None
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,10 +1,16 @@
|
||||||
pub(crate) mod abstraction;
|
pub(crate) mod abstraction;
|
||||||
pub(crate) mod key_value;
|
pub(crate) mod key_value;
|
||||||
|
|
||||||
use crate::{
|
use std::{
|
||||||
service::rooms::{edus::presence::presence_handler, timeline::PduCount},
|
collections::{BTreeMap, HashMap, HashSet},
|
||||||
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
|
fs::{self},
|
||||||
|
io::Write,
|
||||||
|
mem::size_of,
|
||||||
|
path::Path,
|
||||||
|
sync::{Arc, Mutex, RwLock},
|
||||||
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use abstraction::{KeyValueDatabaseEngine, KvTree};
|
use abstraction::{KeyValueDatabaseEngine, KvTree};
|
||||||
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
|
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
|
||||||
use itertools::Itertools;
|
use itertools::Itertools;
|
||||||
|
@ -18,23 +24,17 @@ use ruma::{
|
||||||
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
||||||
},
|
},
|
||||||
push::Ruleset,
|
push::Ruleset,
|
||||||
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
|
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||||
UserId,
|
|
||||||
};
|
};
|
||||||
use serde::Deserialize;
|
use serde::Deserialize;
|
||||||
use std::{
|
|
||||||
collections::{BTreeMap, HashMap, HashSet},
|
|
||||||
fs::{self},
|
|
||||||
io::Write,
|
|
||||||
mem::size_of,
|
|
||||||
path::Path,
|
|
||||||
sync::{Arc, Mutex, RwLock},
|
|
||||||
time::Duration,
|
|
||||||
};
|
|
||||||
use tokio::{sync::mpsc, time::interval};
|
use tokio::{sync::mpsc, time::interval};
|
||||||
|
|
||||||
use tracing::{debug, error, info, warn};
|
use tracing::{debug, error, info, warn};
|
||||||
|
|
||||||
|
use crate::{
|
||||||
|
service::rooms::{edus::presence::presence_handler, timeline::PduCount},
|
||||||
|
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
|
||||||
|
};
|
||||||
|
|
||||||
pub struct KeyValueDatabase {
|
pub struct KeyValueDatabase {
|
||||||
db: Arc<dyn KeyValueDatabaseEngine>,
|
db: Arc<dyn KeyValueDatabaseEngine>,
|
||||||
|
|
||||||
|
@ -128,12 +128,15 @@ pub struct KeyValueDatabase {
|
||||||
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
|
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
|
||||||
|
|
||||||
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
|
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
|
||||||
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
|
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
|
||||||
|
* (shortstatekey+shorteventid++) + 0_u64 +
|
||||||
|
* (shortstatekey+shorteventid--) */
|
||||||
|
|
||||||
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
|
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
|
||||||
|
|
||||||
/// RoomId + EventId -> outlier PDU.
|
/// RoomId + EventId -> outlier PDU.
|
||||||
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
|
/// Any pdu that has passed the steps 1-8 in the incoming event
|
||||||
|
/// /federation/send/txn.
|
||||||
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
|
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
|
||||||
pub(super) softfailedeventids: Arc<dyn KvTree>,
|
pub(super) softfailedeventids: Arc<dyn KvTree>,
|
||||||
|
|
||||||
|
@ -155,11 +158,14 @@ pub struct KeyValueDatabase {
|
||||||
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||||
|
|
||||||
//pub transaction_ids: transaction_ids::TransactionIds,
|
//pub transaction_ids: transaction_ids::TransactionIds,
|
||||||
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
|
||||||
|
* (/send) */
|
||||||
//pub sending: sending::Sending,
|
//pub sending: sending::Sending,
|
||||||
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
|
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
|
||||||
pub(super) servernameevent_data: Arc<dyn KvTree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
pub(super) servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
|
||||||
pub(super) servercurrentevent_data: Arc<dyn KvTree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
* PduId / Id (for edus), Data = EDU content */
|
||||||
|
pub(super) servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
|
||||||
|
* / Id (for edus), Data = EDU content */
|
||||||
|
|
||||||
//pub appservice: appservice::Appservice,
|
//pub appservice: appservice::Appservice,
|
||||||
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
|
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
|
||||||
|
@ -223,10 +229,14 @@ impl KeyValueDatabase {
|
||||||
|
|
||||||
if !Path::new(&config.database_path).exists() {
|
if !Path::new(&config.database_path).exists() {
|
||||||
debug!("Database path does not exist, assuming this is a new setup and creating it");
|
debug!("Database path does not exist, assuming this is a new setup and creating it");
|
||||||
std::fs::create_dir_all(&config.database_path)
|
std::fs::create_dir_all(&config.database_path).map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
error!("Failed to create database path: {e}");
|
error!("Failed to create database path: {e}");
|
||||||
Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself or allow conduwuit the permissions to create directories and files.")})?;
|
Error::BadConfig(
|
||||||
|
"Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \
|
||||||
|
create the database folder yourself or allow conduwuit the permissions to create directories and \
|
||||||
|
files.",
|
||||||
|
)
|
||||||
|
})?;
|
||||||
}
|
}
|
||||||
|
|
||||||
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
|
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
|
||||||
|
@ -236,17 +246,19 @@ impl KeyValueDatabase {
|
||||||
return Err(Error::BadConfig("Database backend not found."));
|
return Err(Error::BadConfig("Database backend not found."));
|
||||||
#[cfg(feature = "sqlite")]
|
#[cfg(feature = "sqlite")]
|
||||||
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
|
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
|
||||||
}
|
},
|
||||||
"rocksdb" => {
|
"rocksdb" => {
|
||||||
debug!("Got rocksdb database backend");
|
debug!("Got rocksdb database backend");
|
||||||
#[cfg(not(feature = "rocksdb"))]
|
#[cfg(not(feature = "rocksdb"))]
|
||||||
return Err(Error::BadConfig("Database backend not found."));
|
return Err(Error::BadConfig("Database backend not found."));
|
||||||
#[cfg(feature = "rocksdb")]
|
#[cfg(feature = "rocksdb")]
|
||||||
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
|
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
|
||||||
}
|
},
|
||||||
_ => {
|
_ => {
|
||||||
return Err(Error::BadConfig("Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends."));
|
return Err(Error::BadConfig(
|
||||||
}
|
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
|
||||||
|
));
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let (presence_sender, presence_receiver) = mpsc::unbounded_channel();
|
let (presence_sender, presence_receiver) = mpsc::unbounded_channel();
|
||||||
|
@ -275,8 +287,7 @@ impl KeyValueDatabase {
|
||||||
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
|
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
|
||||||
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
|
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
|
||||||
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
|
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
|
||||||
roomuserid_lastprivatereadupdate: builder
|
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
|
||||||
.open_tree("roomuserid_lastprivatereadupdate")?,
|
|
||||||
typingid_userid: builder.open_tree("typingid_userid")?,
|
typingid_userid: builder.open_tree("typingid_userid")?,
|
||||||
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
|
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
|
||||||
roomuserid_presence: builder.open_tree("roomuserid_presence")?,
|
roomuserid_presence: builder.open_tree("roomuserid_presence")?,
|
||||||
|
@ -352,14 +363,9 @@ impl KeyValueDatabase {
|
||||||
|
|
||||||
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
||||||
pdu_cache: Mutex::new(LruCache::new(
|
pdu_cache: Mutex::new(LruCache::new(
|
||||||
config
|
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
|
||||||
.pdu_cache_capacity
|
|
||||||
.try_into()
|
|
||||||
.expect("pdu cache capacity fits into usize"),
|
|
||||||
)),
|
|
||||||
auth_chain_cache: Mutex::new(LruCache::new(
|
|
||||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
|
||||||
)),
|
)),
|
||||||
|
auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)),
|
||||||
shorteventid_cache: Mutex::new(LruCache::new(
|
shorteventid_cache: Mutex::new(LruCache::new(
|
||||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||||
)),
|
)),
|
||||||
|
@ -388,17 +394,14 @@ impl KeyValueDatabase {
|
||||||
// Matrix resource ownership is based on the server name; changing it
|
// Matrix resource ownership is based on the server name; changing it
|
||||||
// requires recreating the database from scratch.
|
// requires recreating the database from scratch.
|
||||||
if services().users.count()? > 0 {
|
if services().users.count()? > 0 {
|
||||||
let conduit_user =
|
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||||
UserId::parse_with_server_name("conduit", services().globals.server_name())
|
|
||||||
.expect("@conduit:server_name is valid");
|
.expect("@conduit:server_name is valid");
|
||||||
|
|
||||||
if !services().users.exists(&conduit_user)? {
|
if !services().users.exists(&conduit_user)? {
|
||||||
error!(
|
error!("The {} server user does not exist, and the database is not new.", conduit_user);
|
||||||
"The {} server user does not exist, and the database is not new.",
|
|
||||||
conduit_user
|
|
||||||
);
|
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database(
|
||||||
"Cannot reuse an existing database after changing the server name, please delete the old one first."
|
"Cannot reuse an existing database after changing the server name, please delete the old one \
|
||||||
|
first.",
|
||||||
));
|
));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -415,17 +418,17 @@ impl KeyValueDatabase {
|
||||||
// MIGRATIONS
|
// MIGRATIONS
|
||||||
if services().globals.database_version()? < 1 {
|
if services().globals.database_version()? < 1 {
|
||||||
for (roomserverid, _) in db.roomserverids.iter() {
|
for (roomserverid, _) in db.roomserverids.iter() {
|
||||||
let mut parts = roomserverid.split(|&b| b == 0xff);
|
let mut parts = roomserverid.split(|&b| b == 0xFF);
|
||||||
let room_id = parts.next().expect("split always returns one element");
|
let room_id = parts.next().expect("split always returns one element");
|
||||||
let servername = match parts.next() {
|
let servername = match parts.next() {
|
||||||
Some(s) => s,
|
Some(s) => s,
|
||||||
None => {
|
None => {
|
||||||
error!("Migration: Invalid roomserverid in db.");
|
error!("Migration: Invalid roomserverid in db.");
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
let mut serverroomid = servername.to_vec();
|
let mut serverroomid = servername.to_vec();
|
||||||
serverroomid.push(0xff);
|
serverroomid.push(0xFF);
|
||||||
serverroomid.extend_from_slice(room_id);
|
serverroomid.extend_from_slice(room_id);
|
||||||
|
|
||||||
db.serverroomids.insert(&serverroomid, &[])?;
|
db.serverroomids.insert(&serverroomid, &[])?;
|
||||||
|
@ -445,11 +448,8 @@ impl KeyValueDatabase {
|
||||||
.argon
|
.argon
|
||||||
.hash_password(b"", &salt)
|
.hash_password(b"", &salt)
|
||||||
.expect("our own password to be properly hashed");
|
.expect("our own password to be properly hashed");
|
||||||
let empty_hashed_password = services()
|
let empty_hashed_password =
|
||||||
.globals
|
services().globals.argon.verify_password(&password, &empty_pass).is_ok();
|
||||||
.argon
|
|
||||||
.verify_password(&password, &empty_pass)
|
|
||||||
.is_ok();
|
|
||||||
|
|
||||||
if empty_hashed_password {
|
if empty_hashed_password {
|
||||||
db.userid_password.insert(&userid, b"")?;
|
db.userid_password.insert(&userid, b"")?;
|
||||||
|
@ -506,19 +506,18 @@ impl KeyValueDatabase {
|
||||||
if services().globals.database_version()? < 5 {
|
if services().globals.database_version()? < 5 {
|
||||||
// Upgrade user data store
|
// Upgrade user data store
|
||||||
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
|
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
|
||||||
let mut parts = roomuserdataid.split(|&b| b == 0xff);
|
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
|
||||||
let room_id = parts.next().unwrap();
|
let room_id = parts.next().unwrap();
|
||||||
let user_id = parts.next().unwrap();
|
let user_id = parts.next().unwrap();
|
||||||
let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap();
|
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
|
||||||
|
|
||||||
let mut key = room_id.to_vec();
|
let mut key = room_id.to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(user_id);
|
key.extend_from_slice(user_id);
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(event_type);
|
key.extend_from_slice(event_type);
|
||||||
|
|
||||||
db.roomusertype_roomuserdataid
|
db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||||
.insert(&key, &roomuserdataid)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
services().globals.bump_database_version(5)?;
|
services().globals.bump_database_version(5)?;
|
||||||
|
@ -547,8 +546,7 @@ impl KeyValueDatabase {
|
||||||
let mut current_state = HashSet::new();
|
let mut current_state = HashSet::new();
|
||||||
let mut counter = 0;
|
let mut counter = 0;
|
||||||
|
|
||||||
let mut handle_state =
|
let mut handle_state = |current_sstatehash: u64,
|
||||||
|current_sstatehash: u64,
|
|
||||||
current_room: &RoomId,
|
current_room: &RoomId,
|
||||||
current_state: HashSet<_>,
|
current_state: HashSet<_>,
|
||||||
last_roomstates: &mut HashMap<_, _>| {
|
last_roomstates: &mut HashMap<_, _>| {
|
||||||
|
@ -558,25 +556,16 @@ impl KeyValueDatabase {
|
||||||
let states_parents = last_roomsstatehash.map_or_else(
|
let states_parents = last_roomsstatehash.map_or_else(
|
||||||
|| Ok(Vec::new()),
|
|| Ok(Vec::new()),
|
||||||
|&last_roomsstatehash| {
|
|&last_roomsstatehash| {
|
||||||
services()
|
services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash)
|
||||||
.rooms
|
|
||||||
.state_compressor
|
|
||||||
.load_shortstatehash_info(last_roomsstatehash)
|
|
||||||
},
|
},
|
||||||
)?;
|
)?;
|
||||||
|
|
||||||
let (statediffnew, statediffremoved) =
|
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
|
||||||
if let Some(parent_stateinfo) = states_parents.last() {
|
let statediffnew =
|
||||||
let statediffnew = current_state
|
current_state.difference(&parent_stateinfo.1).copied().collect::<HashSet<_>>();
|
||||||
.difference(&parent_stateinfo.1)
|
|
||||||
.copied()
|
|
||||||
.collect::<HashSet<_>>();
|
|
||||||
|
|
||||||
let statediffremoved = parent_stateinfo
|
let statediffremoved =
|
||||||
.1
|
parent_stateinfo.1.difference(¤t_state).copied().collect::<HashSet<_>>();
|
||||||
.difference(¤t_state)
|
|
||||||
.copied()
|
|
||||||
.collect::<HashSet<_>>();
|
|
||||||
|
|
||||||
(statediffnew, statediffremoved)
|
(statediffnew, statediffremoved)
|
||||||
} else {
|
} else {
|
||||||
|
@ -617,8 +606,8 @@ impl KeyValueDatabase {
|
||||||
};
|
};
|
||||||
|
|
||||||
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
|
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
|
||||||
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
|
let sstatehash =
|
||||||
.expect("number of bytes is correct");
|
utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
|
||||||
let sstatekey = k[size_of::<u64>()..].to_vec();
|
let sstatekey = k[size_of::<u64>()..].to_vec();
|
||||||
if Some(sstatehash) != current_sstatehash {
|
if Some(sstatehash) != current_sstatehash {
|
||||||
if let Some(current_sstatehash) = current_sstatehash {
|
if let Some(current_sstatehash) = current_sstatehash {
|
||||||
|
@ -628,8 +617,7 @@ impl KeyValueDatabase {
|
||||||
current_state,
|
current_state,
|
||||||
&mut last_roomstates,
|
&mut last_roomstates,
|
||||||
)?;
|
)?;
|
||||||
last_roomstates
|
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
|
||||||
.insert(current_room.clone().unwrap(), current_sstatehash);
|
|
||||||
}
|
}
|
||||||
current_state = HashSet::new();
|
current_state = HashSet::new();
|
||||||
current_sstatehash = Some(sstatehash);
|
current_sstatehash = Some(sstatehash);
|
||||||
|
@ -637,12 +625,7 @@ impl KeyValueDatabase {
|
||||||
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
|
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
|
||||||
let string = utils::string_from_bytes(&event_id).unwrap();
|
let string = utils::string_from_bytes(&event_id).unwrap();
|
||||||
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
|
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
|
||||||
let pdu = services()
|
let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap();
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.get_pdu(event_id)
|
|
||||||
.unwrap()
|
|
||||||
.unwrap();
|
|
||||||
|
|
||||||
if Some(&pdu.room_id) != current_room.as_ref() {
|
if Some(&pdu.room_id) != current_room.as_ref() {
|
||||||
current_room = Some(pdu.room_id.clone());
|
current_room = Some(pdu.room_id.clone());
|
||||||
|
@ -680,15 +663,11 @@ impl KeyValueDatabase {
|
||||||
if !key.starts_with(b"!") {
|
if !key.starts_with(b"!") {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||||
let room_id = parts.next().unwrap();
|
let room_id = parts.next().unwrap();
|
||||||
let count = parts.next().unwrap();
|
let count = parts.next().unwrap();
|
||||||
|
|
||||||
let short_room_id = db
|
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||||
.roomid_shortroomid
|
|
||||||
.get(room_id)
|
|
||||||
.unwrap()
|
|
||||||
.expect("shortroomid should exist");
|
|
||||||
|
|
||||||
let mut new_key = short_room_id;
|
let mut new_key = short_room_id;
|
||||||
new_key.extend_from_slice(count);
|
new_key.extend_from_slice(count);
|
||||||
|
@ -702,15 +681,11 @@ impl KeyValueDatabase {
|
||||||
if !value.starts_with(b"!") {
|
if !value.starts_with(b"!") {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let mut parts = value.splitn(2, |&b| b == 0xff);
|
let mut parts = value.splitn(2, |&b| b == 0xFF);
|
||||||
let room_id = parts.next().unwrap();
|
let room_id = parts.next().unwrap();
|
||||||
let count = parts.next().unwrap();
|
let count = parts.next().unwrap();
|
||||||
|
|
||||||
let short_room_id = db
|
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||||
.roomid_shortroomid
|
|
||||||
.get(room_id)
|
|
||||||
.unwrap()
|
|
||||||
.expect("shortroomid should exist");
|
|
||||||
|
|
||||||
let mut new_value = short_room_id;
|
let mut new_value = short_room_id;
|
||||||
new_value.extend_from_slice(count);
|
new_value.extend_from_slice(count);
|
||||||
|
@ -734,20 +709,17 @@ impl KeyValueDatabase {
|
||||||
if !key.starts_with(b"!") {
|
if !key.starts_with(b"!") {
|
||||||
return None;
|
return None;
|
||||||
}
|
}
|
||||||
let mut parts = key.splitn(4, |&b| b == 0xff);
|
let mut parts = key.splitn(4, |&b| b == 0xFF);
|
||||||
let room_id = parts.next().unwrap();
|
let room_id = parts.next().unwrap();
|
||||||
let word = parts.next().unwrap();
|
let word = parts.next().unwrap();
|
||||||
let _pdu_id_room = parts.next().unwrap();
|
let _pdu_id_room = parts.next().unwrap();
|
||||||
let pdu_id_count = parts.next().unwrap();
|
let pdu_id_count = parts.next().unwrap();
|
||||||
|
|
||||||
let short_room_id = db
|
let short_room_id =
|
||||||
.roomid_shortroomid
|
db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||||
.get(room_id)
|
|
||||||
.unwrap()
|
|
||||||
.expect("shortroomid should exist");
|
|
||||||
let mut new_key = short_room_id;
|
let mut new_key = short_room_id;
|
||||||
new_key.extend_from_slice(word);
|
new_key.extend_from_slice(word);
|
||||||
new_key.push(0xff);
|
new_key.push(0xFF);
|
||||||
new_key.extend_from_slice(pdu_id_count);
|
new_key.extend_from_slice(pdu_id_count);
|
||||||
Some((new_key, Vec::new()))
|
Some((new_key, Vec::new()))
|
||||||
})
|
})
|
||||||
|
@ -784,8 +756,7 @@ impl KeyValueDatabase {
|
||||||
if services().globals.database_version()? < 10 {
|
if services().globals.database_version()? < 10 {
|
||||||
// Add other direction for shortstatekeys
|
// Add other direction for shortstatekeys
|
||||||
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
|
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
|
||||||
db.shortstatekey_statekey
|
db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?;
|
||||||
.insert(&shortstatekey, &statekey)?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Force E2EE device list updates so we can send them over federation
|
// Force E2EE device list updates so we can send them over federation
|
||||||
|
@ -799,9 +770,7 @@ impl KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
if services().globals.database_version()? < 11 {
|
if services().globals.database_version()? < 11 {
|
||||||
db.db
|
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?;
|
||||||
.open_tree("userdevicesessionid_uiaarequest")?
|
|
||||||
.clear()?;
|
|
||||||
services().globals.bump_database_version(11)?;
|
services().globals.bump_database_version(11)?;
|
||||||
|
|
||||||
warn!("Migration: 10 -> 11 finished");
|
warn!("Migration: 10 -> 11 finished");
|
||||||
|
@ -809,43 +778,33 @@ impl KeyValueDatabase {
|
||||||
|
|
||||||
if services().globals.database_version()? < 12 {
|
if services().globals.database_version()? < 12 {
|
||||||
for username in services().users.list_local_users()? {
|
for username in services().users.list_local_users()? {
|
||||||
let user = match UserId::parse_with_server_name(
|
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
|
||||||
username.clone(),
|
{
|
||||||
services().globals.server_name(),
|
|
||||||
) {
|
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Invalid username {username}: {e}");
|
warn!("Invalid username {username}: {e}");
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let raw_rules_list = services()
|
let raw_rules_list = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
|
||||||
None,
|
|
||||||
&user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.expect("Username is invalid");
|
.expect("Username is invalid");
|
||||||
|
|
||||||
let mut account_data =
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||||
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
|
||||||
let rules_list = &mut account_data.content.global;
|
let rules_list = &mut account_data.content.global;
|
||||||
|
|
||||||
//content rule
|
//content rule
|
||||||
{
|
{
|
||||||
let content_rule_transformation =
|
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
||||||
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
|
||||||
|
|
||||||
let rule = rules_list.content.get(content_rule_transformation[0]);
|
let rule = rules_list.content.get(content_rule_transformation[0]);
|
||||||
if rule.is_some() {
|
if rule.is_some() {
|
||||||
let mut rule = rule.unwrap().clone();
|
let mut rule = rule.unwrap().clone();
|
||||||
rule.rule_id = content_rule_transformation[1].to_owned();
|
rule.rule_id = content_rule_transformation[1].to_owned();
|
||||||
rules_list
|
rules_list.content.shift_remove(content_rule_transformation[0]);
|
||||||
.content
|
|
||||||
.shift_remove(content_rule_transformation[0]);
|
|
||||||
rules_list.content.insert(rule);
|
rules_list.content.insert(rule);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -855,10 +814,7 @@ impl KeyValueDatabase {
|
||||||
let underride_rule_transformation = [
|
let underride_rule_transformation = [
|
||||||
[".m.rules.call", ".m.rule.call"],
|
[".m.rules.call", ".m.rule.call"],
|
||||||
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
|
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
|
||||||
[
|
[".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
|
||||||
".m.rules.encrypted_room_one_to_one",
|
|
||||||
".m.rule.encrypted_room_one_to_one",
|
|
||||||
],
|
|
||||||
[".m.rules.message", ".m.rule.message"],
|
[".m.rules.message", ".m.rule.message"],
|
||||||
[".m.rules.encrypted", ".m.rule.encrypted"],
|
[".m.rules.encrypted", ".m.rule.encrypted"],
|
||||||
];
|
];
|
||||||
|
@ -887,38 +843,29 @@ impl KeyValueDatabase {
|
||||||
warn!("Migration: 11 -> 12 finished");
|
warn!("Migration: 11 -> 12 finished");
|
||||||
}
|
}
|
||||||
|
|
||||||
// This migration can be reused as-is anytime the server-default rules are updated.
|
// This migration can be reused as-is anytime the server-default rules are
|
||||||
|
// updated.
|
||||||
if services().globals.database_version()? < 13 {
|
if services().globals.database_version()? < 13 {
|
||||||
for username in services().users.list_local_users()? {
|
for username in services().users.list_local_users()? {
|
||||||
let user = match UserId::parse_with_server_name(
|
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
|
||||||
username.clone(),
|
{
|
||||||
services().globals.server_name(),
|
|
||||||
) {
|
|
||||||
Ok(u) => u,
|
Ok(u) => u,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Invalid username {username}: {e}");
|
warn!("Invalid username {username}: {e}");
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let raw_rules_list = services()
|
let raw_rules_list = services()
|
||||||
.account_data
|
.account_data
|
||||||
.get(
|
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
|
||||||
None,
|
|
||||||
&user,
|
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
|
||||||
)
|
|
||||||
.unwrap()
|
.unwrap()
|
||||||
.expect("Username is invalid");
|
.expect("Username is invalid");
|
||||||
|
|
||||||
let mut account_data =
|
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||||
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
|
||||||
|
|
||||||
let user_default_rules = ruma::push::Ruleset::server_default(&user);
|
let user_default_rules = ruma::push::Ruleset::server_default(&user);
|
||||||
account_data
|
account_data.content.global.update_with_server_default(user_default_rules);
|
||||||
.content
|
|
||||||
.global
|
|
||||||
.update_with_server_default(user_default_rules);
|
|
||||||
|
|
||||||
services().account_data.update(
|
services().account_data.update(
|
||||||
None,
|
None,
|
||||||
|
@ -937,8 +884,8 @@ impl KeyValueDatabase {
|
||||||
warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
|
warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
|
||||||
// Move old media files to new names
|
// Move old media files to new names
|
||||||
for (key, _) in db.mediaid_file.iter() {
|
for (key, _) in db.mediaid_file.iter() {
|
||||||
// we know that this method is deprecated, but we need to use it to migrate the old files
|
// we know that this method is deprecated, but we need to use it to migrate the
|
||||||
// to the new location
|
// old files to the new location
|
||||||
//
|
//
|
||||||
// TODO: remove this once we're sure that all users have migrated
|
// TODO: remove this once we're sure that all users have migrated
|
||||||
#[allow(deprecated)]
|
#[allow(deprecated)]
|
||||||
|
@ -957,7 +904,10 @@ impl KeyValueDatabase {
|
||||||
|
|
||||||
assert_eq!(
|
assert_eq!(
|
||||||
services().globals.database_version().unwrap(),
|
services().globals.database_version().unwrap(),
|
||||||
latest_database_version, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", services().globals.database_version().unwrap(), latest_database_version
|
latest_database_version,
|
||||||
|
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
|
||||||
|
services().globals.database_version().unwrap(),
|
||||||
|
latest_database_version
|
||||||
);
|
);
|
||||||
|
|
||||||
{
|
{
|
||||||
|
@ -970,10 +920,7 @@ impl KeyValueDatabase {
|
||||||
warn!(
|
warn!(
|
||||||
"User {} matches the following forbidden username patterns: {}",
|
"User {} matches the following forbidden username patterns: {}",
|
||||||
user_id.to_string(),
|
user_id.to_string(),
|
||||||
matches
|
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
|
||||||
.into_iter()
|
|
||||||
.map(|x| &patterns.patterns()[x])
|
|
||||||
.join(", ")
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -994,10 +941,7 @@ impl KeyValueDatabase {
|
||||||
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
|
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
|
||||||
room_alias,
|
room_alias,
|
||||||
&room_id,
|
&room_id,
|
||||||
matches
|
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
|
||||||
.into_iter()
|
|
||||||
.map(|x| &patterns.patterns()[x])
|
|
||||||
.join(", ")
|
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -1011,9 +955,7 @@ impl KeyValueDatabase {
|
||||||
latest_database_version
|
latest_database_version
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
services()
|
services().globals.bump_database_version(latest_database_version)?;
|
||||||
.globals
|
|
||||||
.bump_database_version(latest_database_version)?;
|
|
||||||
|
|
||||||
// Create the admin room and server user on first run
|
// Create the admin room and server user on first run
|
||||||
services().admin.create_admin_room().await?;
|
services().admin.create_admin_room().await?;
|
||||||
|
@ -1031,16 +973,19 @@ impl KeyValueDatabase {
|
||||||
match set_emergency_access() {
|
match set_emergency_access() {
|
||||||
Ok(pwd_set) => {
|
Ok(pwd_set) => {
|
||||||
if pwd_set {
|
if pwd_set {
|
||||||
warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!");
|
warn!(
|
||||||
services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
|
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
|
||||||
}
|
account recovery!"
|
||||||
}
|
|
||||||
Err(e) => {
|
|
||||||
error!(
|
|
||||||
"Could not set the configured emergency password for the conduit user: {}",
|
|
||||||
e
|
|
||||||
);
|
);
|
||||||
|
services().admin.send_message(RoomMessageEventContent::text_plain(
|
||||||
|
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
|
||||||
|
account recovery!",
|
||||||
|
));
|
||||||
}
|
}
|
||||||
|
},
|
||||||
|
Err(e) => {
|
||||||
|
error!("Could not set the configured emergency password for the conduit user: {}", e);
|
||||||
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
services().sending.start_handler();
|
services().sending.start_handler();
|
||||||
|
@ -1079,12 +1024,8 @@ impl KeyValueDatabase {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn try_handle_updates() -> Result<()> {
|
async fn try_handle_updates() -> Result<()> {
|
||||||
let response = services()
|
let response =
|
||||||
.globals
|
services().globals.default_client().get("https://pupbrain.dev/check-for-updates/stable").send().await?;
|
||||||
.default_client()
|
|
||||||
.get("https://pupbrain.dev/check-for-updates/stable")
|
|
||||||
.send()
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
#[derive(Deserialize)]
|
#[derive(Deserialize)]
|
||||||
struct CheckForUpdatesResponseEntry {
|
struct CheckForUpdatesResponseEntry {
|
||||||
|
@ -1097,8 +1038,7 @@ impl KeyValueDatabase {
|
||||||
updates: Vec<CheckForUpdatesResponseEntry>,
|
updates: Vec<CheckForUpdatesResponseEntry>,
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?)
|
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
|
||||||
.map_err(|e| {
|
|
||||||
error!("Bad check for updates response: {e}");
|
error!("Bad check for updates response: {e}");
|
||||||
Error::BadServerResponse("Bad version check response")
|
Error::BadServerResponse("Bad version check response")
|
||||||
})?;
|
})?;
|
||||||
|
@ -1108,17 +1048,13 @@ impl KeyValueDatabase {
|
||||||
last_update_id = last_update_id.max(update.id);
|
last_update_id = last_update_id.max(update.id);
|
||||||
if update.id > services().globals.last_check_for_updates_id()? {
|
if update.id > services().globals.last_check_for_updates_id()? {
|
||||||
error!("{}", update.message);
|
error!("{}", update.message);
|
||||||
services()
|
services().admin.send_message(RoomMessageEventContent::text_plain(format!(
|
||||||
.admin
|
|
||||||
.send_message(RoomMessageEventContent::text_plain(format!(
|
|
||||||
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
|
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
|
||||||
update.date, update.message
|
update.date, update.message
|
||||||
)));
|
)));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
services()
|
services().globals.update_check_for_updates_id(last_update_id)?;
|
||||||
.globals
|
|
||||||
.update_check_for_updates_id(last_update_id)?;
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -1129,8 +1065,7 @@ impl KeyValueDatabase {
|
||||||
use tokio::signal::unix::{signal, SignalKind};
|
use tokio::signal::unix::{signal, SignalKind};
|
||||||
use tokio::time::Instant;
|
use tokio::time::Instant;
|
||||||
|
|
||||||
let timer_interval =
|
let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
|
||||||
Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
|
|
||||||
|
|
||||||
fn perform_cleanup() {
|
fn perform_cleanup() {
|
||||||
let start = Instant::now();
|
let start = Instant::now();
|
||||||
|
@ -1176,9 +1111,7 @@ impl KeyValueDatabase {
|
||||||
});
|
});
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn start_presence_handler(
|
pub async fn start_presence_handler(presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>) {
|
||||||
presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>,
|
|
||||||
) {
|
|
||||||
tokio::spawn(async move {
|
tokio::spawn(async move {
|
||||||
match presence_handler(presence_timer_receiver).await {
|
match presence_handler(presence_timer_receiver).await {
|
||||||
Ok(()) => warn!("Presence maintenance task finished"),
|
Ok(()) => warn!("Presence maintenance task finished"),
|
||||||
|
@ -1188,15 +1121,13 @@ impl KeyValueDatabase {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set
|
/// Sets the emergency password and push rules for the @conduit account in case
|
||||||
|
/// emergency password is set
|
||||||
fn set_emergency_access() -> Result<bool> {
|
fn set_emergency_access() -> Result<bool> {
|
||||||
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||||
.expect("@conduit:server_name is a valid UserId");
|
.expect("@conduit:server_name is a valid UserId");
|
||||||
|
|
||||||
services().users.set_password(
|
services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
|
||||||
&conduit_user,
|
|
||||||
services().globals.emergency_password().as_deref(),
|
|
||||||
)?;
|
|
||||||
|
|
||||||
let (ruleset, res) = match services().globals.emergency_password() {
|
let (ruleset, res) = match services().globals.emergency_password() {
|
||||||
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
|
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
|
||||||
|
@ -1208,7 +1139,9 @@ fn set_emergency_access() -> Result<bool> {
|
||||||
&conduit_user,
|
&conduit_user,
|
||||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||||
&serde_json::to_value(&GlobalAccountDataEvent {
|
&serde_json::to_value(&GlobalAccountDataEvent {
|
||||||
content: PushRulesEventContent { global: ruleset },
|
content: PushRulesEventContent {
|
||||||
|
global: ruleset,
|
||||||
|
},
|
||||||
})
|
})
|
||||||
.expect("to json value always works"),
|
.expect("to json value always works"),
|
||||||
)?;
|
)?;
|
||||||
|
|
|
@ -15,8 +15,5 @@ pub use utils::error::{Error, Result};
|
||||||
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
|
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
|
||||||
|
|
||||||
pub fn services() -> &'static Services<'static> {
|
pub fn services() -> &'static Services<'static> {
|
||||||
SERVICES
|
SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called")
|
||||||
.read()
|
|
||||||
.unwrap()
|
|
||||||
.expect("SERVICES should be initialized when this is called")
|
|
||||||
}
|
}
|
||||||
|
|
266
src/main.rs
266
src/main.rs
|
@ -1,6 +1,6 @@
|
||||||
use std::{
|
use std::{
|
||||||
fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path,
|
fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path, sync::atomic,
|
||||||
sync::atomic, time::Duration,
|
time::Duration,
|
||||||
};
|
};
|
||||||
|
|
||||||
use axum::{
|
use axum::{
|
||||||
|
@ -10,7 +10,11 @@ use axum::{
|
||||||
Router,
|
Router,
|
||||||
};
|
};
|
||||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||||
|
#[cfg(feature = "axum_dual_protocol")]
|
||||||
|
use axum_server_dual_protocol::ServerExt;
|
||||||
|
use clap::Parser;
|
||||||
use conduit::api::{client_server, server_server};
|
use conduit::api::{client_server, server_server};
|
||||||
|
pub use conduit::*; // Re-export everything from the library crate
|
||||||
use either::Either::{Left, Right};
|
use either::Either::{Left, Right};
|
||||||
use figment::{
|
use figment::{
|
||||||
providers::{Env, Format, Toml},
|
providers::{Env, Format, Toml},
|
||||||
|
@ -29,7 +33,14 @@ use ruma::api::{
|
||||||
},
|
},
|
||||||
IncomingRequest,
|
IncomingRequest,
|
||||||
};
|
};
|
||||||
use tokio::{net::UnixListener, signal, sync::oneshot, task::JoinSet};
|
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||||
|
use tikv_jemallocator::Jemalloc;
|
||||||
|
use tokio::{
|
||||||
|
net::UnixListener,
|
||||||
|
signal,
|
||||||
|
sync::{oneshot, oneshot::Sender},
|
||||||
|
task::JoinSet,
|
||||||
|
};
|
||||||
use tower::ServiceBuilder;
|
use tower::ServiceBuilder;
|
||||||
use tower_http::{
|
use tower_http::{
|
||||||
cors::{self, CorsLayer},
|
cors::{self, CorsLayer},
|
||||||
|
@ -39,18 +50,6 @@ use tower_http::{
|
||||||
use tracing::{debug, error, info, warn, Level};
|
use tracing::{debug, error, info, warn, Level};
|
||||||
use tracing_subscriber::{prelude::*, EnvFilter};
|
use tracing_subscriber::{prelude::*, EnvFilter};
|
||||||
|
|
||||||
use tokio::sync::oneshot::Sender;
|
|
||||||
|
|
||||||
use clap::Parser;
|
|
||||||
|
|
||||||
#[cfg(feature = "axum_dual_protocol")]
|
|
||||||
use axum_server_dual_protocol::ServerExt;
|
|
||||||
|
|
||||||
pub use conduit::*; // Re-export everything from the library crate
|
|
||||||
|
|
||||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
|
||||||
use tikv_jemallocator::Jemalloc;
|
|
||||||
|
|
||||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||||
#[global_allocator]
|
#[global_allocator]
|
||||||
static GLOBAL: Jemalloc = Jemalloc;
|
static GLOBAL: Jemalloc = Jemalloc;
|
||||||
|
@ -67,7 +66,8 @@ async fn main() {
|
||||||
Figment::new()
|
Figment::new()
|
||||||
.merge(
|
.merge(
|
||||||
Toml::file(Env::var("CONDUIT_CONFIG").expect(
|
Toml::file(Env::var("CONDUIT_CONFIG").expect(
|
||||||
"The CONDUIT_CONFIG environment variable was set but appears to be invalid. This should be set to the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.",
|
"The CONDUIT_CONFIG environment variable was set but appears to be invalid. This should be set to \
|
||||||
|
the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.",
|
||||||
))
|
))
|
||||||
.nested(),
|
.nested(),
|
||||||
)
|
)
|
||||||
|
@ -81,7 +81,7 @@ async fn main() {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("It looks like your config is invalid. The following error occurred: {e}");
|
eprintln!("It looks like your config is invalid. The following error occurred: {e}");
|
||||||
return;
|
return;
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
if config.allow_jaeger {
|
if config.allow_jaeger {
|
||||||
|
@ -96,21 +96,16 @@ async fn main() {
|
||||||
let filter_layer = match EnvFilter::try_new(&config.log) {
|
let filter_layer = match EnvFilter::try_new(&config.log) {
|
||||||
Ok(s) => s,
|
Ok(s) => s,
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!(
|
eprintln!("It looks like your log config is invalid. The following error occurred: {e}");
|
||||||
"It looks like your log config is invalid. The following error occurred: {e}"
|
|
||||||
);
|
|
||||||
EnvFilter::try_new("warn").unwrap()
|
EnvFilter::try_new("warn").unwrap()
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let subscriber = tracing_subscriber::Registry::default()
|
let subscriber = tracing_subscriber::Registry::default().with(filter_layer).with(telemetry);
|
||||||
.with(filter_layer)
|
|
||||||
.with(telemetry);
|
|
||||||
tracing::subscriber::set_global_default(subscriber).unwrap();
|
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||||
} else if config.tracing_flame {
|
} else if config.tracing_flame {
|
||||||
let registry = tracing_subscriber::Registry::default();
|
let registry = tracing_subscriber::Registry::default();
|
||||||
let (flame_layer, _guard) =
|
let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap();
|
||||||
tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap();
|
|
||||||
let flame_layer = flame_layer.with_empty_samples(false);
|
let flame_layer = flame_layer.with_empty_samples(false);
|
||||||
|
|
||||||
let filter_layer = EnvFilter::new("trace,h2=off");
|
let filter_layer = EnvFilter::new("trace,h2=off");
|
||||||
|
@ -125,7 +120,7 @@ async fn main() {
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
|
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
|
||||||
EnvFilter::try_new("warn").unwrap()
|
EnvFilter::try_new("warn").unwrap()
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let subscriber = registry.with(filter_layer).with(fmt_layer);
|
let subscriber = registry.with(filter_layer).with(fmt_layer);
|
||||||
|
@ -172,17 +167,34 @@ async fn main() {
|
||||||
if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists()
|
if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists()
|
||||||
/* Host */
|
/* Host */
|
||||||
{
|
{
|
||||||
error!("You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
error!(
|
||||||
|
"You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \
|
||||||
|
OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \
|
||||||
|
will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.",
|
||||||
|
config.address
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
if Path::new("/.dockerenv").exists() {
|
if Path::new("/.dockerenv").exists() {
|
||||||
error!("You are detected using Docker with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Docker container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
error!(
|
||||||
|
"You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \
|
||||||
|
reverse proxy on the host and require communication to conduwuit in the Docker container via \
|
||||||
|
NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
|
||||||
|
you can ignore.",
|
||||||
|
config.address
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
if Path::new("/run/.containerenv").exists() {
|
if Path::new("/run/.containerenv").exists() {
|
||||||
error!("You are detected using Podman with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Podman container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
error!(
|
||||||
|
"You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \
|
||||||
|
reverse proxy on the host and require communication to conduwuit in the Podman container via \
|
||||||
|
NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
|
||||||
|
you can ignore.",
|
||||||
|
config.address
|
||||||
|
);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -208,18 +220,22 @@ async fn main() {
|
||||||
|
|
||||||
// check if user specified valid IP CIDR ranges on startup
|
// check if user specified valid IP CIDR ranges on startup
|
||||||
for cidr in services().globals.ip_range_denylist() {
|
for cidr in services().globals.ip_range_denylist() {
|
||||||
let _ = ipaddress::IPAddress::parse(cidr)
|
let _ = ipaddress::IPAddress::parse(cidr).map_err(|e| error!("Error parsing specified IP CIDR range: {e}"));
|
||||||
.map_err(|e| error!("Error parsing specified IP CIDR range: {e}"));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.allow_registration
|
if config.allow_registration
|
||||||
&& !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
&& !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
||||||
&& config.registration_token.is_none()
|
&& config.registration_token.is_none()
|
||||||
{
|
{
|
||||||
error!("!! You have `allow_registration` enabled without a token configured in your config which means you are allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n
|
error!(
|
||||||
If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n
|
"!! You have `allow_registration` enabled without a token configured in your config which means you are \
|
||||||
For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you want, please set the following config option to true:
|
allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n
|
||||||
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`");
|
If this is not the intended behaviour, please set a registration token with the `registration_token` config \
|
||||||
|
option.\n
|
||||||
|
For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour \
|
||||||
|
you want, please set the following config option to true:
|
||||||
|
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`"
|
||||||
|
);
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -227,8 +243,12 @@ async fn main() {
|
||||||
&& config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
&& config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
||||||
&& config.registration_token.is_none()
|
&& config.registration_token.is_none()
|
||||||
{
|
{
|
||||||
warn!("Open registration is enabled via setting `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to true without a registration token configured. You are expected to be aware of the risks now.\n
|
warn!(
|
||||||
If this is not the desired behaviour, please set a registration token.");
|
"Open registration is enabled via setting \
|
||||||
|
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \
|
||||||
|
true without a registration token configured. You are expected to be aware of the risks now.\n
|
||||||
|
If this is not the desired behaviour, please set a registration token."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.allow_outgoing_presence && !config.allow_local_presence {
|
if config.allow_outgoing_presence && !config.allow_local_presence {
|
||||||
|
@ -237,26 +257,33 @@ async fn main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.allow_outgoing_presence {
|
if config.allow_outgoing_presence {
|
||||||
warn!("! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing presence will not be very reliable due to this and any issues with federated outgoing presence are very likely attributed to this issue.\nIncoming presence and local presence are unaffected.");
|
warn!(
|
||||||
|
"! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing \
|
||||||
|
presence will not be very reliable due to this and any issues with federated outgoing presence are very \
|
||||||
|
likely attributed to this issue.\nIncoming presence and local presence are unaffected."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
if config
|
if config.url_preview_domain_contains_allowlist.contains(&"*".to_owned()) {
|
||||||
.url_preview_domain_contains_allowlist
|
warn!(
|
||||||
.contains(&"*".to_owned())
|
"All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \
|
||||||
{
|
This opens up significant attack surface to your server. You are expected to be aware of the risks by \
|
||||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
doing this."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if config
|
if config.url_preview_domain_explicit_allowlist.contains(&"*".to_owned()) {
|
||||||
.url_preview_domain_explicit_allowlist
|
warn!(
|
||||||
.contains(&"*".to_owned())
|
"All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \
|
||||||
{
|
This opens up significant attack surface to your server. You are expected to be aware of the risks by \
|
||||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
doing this."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
if config
|
if config.url_preview_url_contains_allowlist.contains(&"*".to_owned()) {
|
||||||
.url_preview_url_contains_allowlist
|
warn!(
|
||||||
.contains(&"*".to_owned())
|
"All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \
|
||||||
{
|
opens up significant attack surface to your server. You are expected to be aware of the risks by doing \
|
||||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
this."
|
||||||
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
/* end ad-hoc config validation/checks */
|
/* end ad-hoc config validation/checks */
|
||||||
|
@ -266,8 +293,9 @@ async fn main() {
|
||||||
error!("Critical error starting server: {}", e);
|
error!("Critical error starting server: {}", e);
|
||||||
};
|
};
|
||||||
|
|
||||||
// if server runs into critical error and shuts down, shut down the tracer provider if jaegar is used.
|
// if server runs into critical error and shuts down, shut down the tracer
|
||||||
// awaiting run_server() is a blocking call so putting this after is fine, but not the other options above.
|
// provider if jaegar is used. awaiting run_server() is a blocking call so
|
||||||
|
// putting this after is fine, but not the other options above.
|
||||||
if config.allow_jaeger {
|
if config.allow_jaeger {
|
||||||
opentelemetry::global::shutdown_tracer_provider();
|
opentelemetry::global::shutdown_tracer_provider();
|
||||||
}
|
}
|
||||||
|
@ -281,17 +309,9 @@ async fn run_server() -> io::Result<()> {
|
||||||
// Left is only 1 value, so make a vec with 1 value only
|
// Left is only 1 value, so make a vec with 1 value only
|
||||||
let port_vec = [port];
|
let port_vec = [port];
|
||||||
|
|
||||||
port_vec
|
port_vec.iter().copied().map(|port| SocketAddr::from((config.address, *port))).collect::<Vec<_>>()
|
||||||
.iter()
|
},
|
||||||
.copied()
|
Right(ports) => ports.iter().copied().map(|port| SocketAddr::from((config.address, port))).collect::<Vec<_>>(),
|
||||||
.map(|port| SocketAddr::from((config.address, *port)))
|
|
||||||
.collect::<Vec<_>>()
|
|
||||||
}
|
|
||||||
Right(ports) => ports
|
|
||||||
.iter()
|
|
||||||
.copied()
|
|
||||||
.map(|port| SocketAddr::from((config.address, port)))
|
|
||||||
.collect::<Vec<_>>(),
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let x_requested_with = HeaderName::from_static("x-requested-with");
|
let x_requested_with = HeaderName::from_static("x-requested-with");
|
||||||
|
@ -336,17 +356,12 @@ async fn run_server() -> io::Result<()> {
|
||||||
.max_age(Duration::from_secs(86400)),
|
.max_age(Duration::from_secs(86400)),
|
||||||
)
|
)
|
||||||
.layer(DefaultBodyLimit::max(
|
.layer(DefaultBodyLimit::max(
|
||||||
config
|
config.max_request_size.try_into().expect("failed to convert max request size"),
|
||||||
.max_request_size
|
|
||||||
.try_into()
|
|
||||||
.expect("failed to convert max request size"),
|
|
||||||
));
|
));
|
||||||
|
|
||||||
let app = if cfg!(feature = "zstd_compression") && config.zstd_compression {
|
let app = if cfg!(feature = "zstd_compression") && config.zstd_compression {
|
||||||
debug!("zstd body compression is enabled");
|
debug!("zstd body compression is enabled");
|
||||||
routes()
|
routes().layer(middlewares.compression()).into_make_service()
|
||||||
.layer(middlewares.compression())
|
|
||||||
.into_make_service()
|
|
||||||
} else {
|
} else {
|
||||||
routes().layer(middlewares).into_make_service()
|
routes().layer(middlewares).into_make_service()
|
||||||
};
|
};
|
||||||
|
@ -371,9 +386,7 @@ async fn run_server() -> io::Result<()> {
|
||||||
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
|
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
|
||||||
|
|
||||||
let listener = UnixListener::bind(path.clone())?;
|
let listener = UnixListener::bind(path.clone())?;
|
||||||
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms))
|
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)).await.unwrap();
|
||||||
.await
|
|
||||||
.unwrap();
|
|
||||||
let socket = SocketIncoming::from_listener(listener);
|
let socket = SocketIncoming::from_listener(listener);
|
||||||
|
|
||||||
#[cfg(feature = "systemd")]
|
#[cfg(feature = "systemd")]
|
||||||
|
@ -395,12 +408,16 @@ async fn run_server() -> io::Result<()> {
|
||||||
"Using direct TLS. Certificate path {} and certificate private key path {}",
|
"Using direct TLS. Certificate path {} and certificate private key path {}",
|
||||||
&tls.certs, &tls.key
|
&tls.certs, &tls.key
|
||||||
);
|
);
|
||||||
info!("Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS.");
|
info!(
|
||||||
|
"Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit \
|
||||||
|
directly with TLS."
|
||||||
|
);
|
||||||
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
|
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;
|
||||||
|
|
||||||
if cfg!(feature = "axum_dual_protocol") {
|
if cfg!(feature = "axum_dual_protocol") {
|
||||||
info!(
|
info!(
|
||||||
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only take affect if `dual_protocol` is enabled in `[global.tls]`"
|
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This \
|
||||||
|
will only take affect if `dual_protocol` is enabled in `[global.tls]`"
|
||||||
);
|
);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -418,11 +435,7 @@ async fn run_server() -> io::Result<()> {
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
for addr in &addrs {
|
for addr in &addrs {
|
||||||
join_set.spawn(
|
join_set.spawn(bind_rustls(*addr, conf.clone()).handle(handle.clone()).serve(app.clone()));
|
||||||
bind_rustls(*addr, conf.clone())
|
|
||||||
.handle(handle.clone())
|
|
||||||
.serve(app.clone()),
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -431,18 +444,16 @@ async fn run_server() -> io::Result<()> {
|
||||||
|
|
||||||
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
|
if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
|
||||||
warn!(
|
warn!(
|
||||||
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)",
|
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too \
|
||||||
|
(insecure!)",
|
||||||
addrs, &tls.certs
|
addrs, &tls.certs
|
||||||
);
|
);
|
||||||
} else {
|
} else {
|
||||||
info!(
|
info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs);
|
||||||
"Listening on {:?} with TLS certificate {}",
|
|
||||||
addrs, &tls.certs
|
|
||||||
);
|
|
||||||
}
|
}
|
||||||
|
|
||||||
join_set.join_next().await;
|
join_set.join_next().await;
|
||||||
}
|
},
|
||||||
None => {
|
None => {
|
||||||
let mut join_set = JoinSet::new();
|
let mut join_set = JoinSet::new();
|
||||||
for addr in &addrs {
|
for addr in &addrs {
|
||||||
|
@ -454,7 +465,7 @@ async fn run_server() -> io::Result<()> {
|
||||||
|
|
||||||
info!("Listening on {:?}", addrs);
|
info!("Listening on {:?}", addrs);
|
||||||
join_set.join_next().await;
|
join_set.join_next().await;
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -462,20 +473,16 @@ async fn run_server() -> io::Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn spawn_task<B: Send + 'static>(
|
async fn spawn_task<B: Send + 'static>(
|
||||||
req: axum::http::Request<B>,
|
req: axum::http::Request<B>, next: axum::middleware::Next<B>,
|
||||||
next: axum::middleware::Next<B>,
|
|
||||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||||
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
|
||||||
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
return Err(StatusCode::SERVICE_UNAVAILABLE);
|
||||||
}
|
}
|
||||||
tokio::spawn(next.run(req))
|
tokio::spawn(next.run(req)).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
||||||
.await
|
|
||||||
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn unrecognized_method<B: Send + 'static>(
|
async fn unrecognized_method<B: Send + 'static>(
|
||||||
req: axum::http::Request<B>,
|
req: axum::http::Request<B>, next: axum::middleware::Next<B>,
|
||||||
next: axum::middleware::Next<B>,
|
|
||||||
) -> std::result::Result<axum::response::Response, StatusCode> {
|
) -> std::result::Result<axum::response::Response, StatusCode> {
|
||||||
let method = req.method().clone();
|
let method = req.method().clone();
|
||||||
let uri = req.uri().clone();
|
let uri = req.uri().clone();
|
||||||
|
@ -637,10 +644,7 @@ fn routes() -> Router {
|
||||||
.ruma_route(client_server::get_relating_events_route)
|
.ruma_route(client_server::get_relating_events_route)
|
||||||
.ruma_route(client_server::get_hierarchy_route)
|
.ruma_route(client_server::get_hierarchy_route)
|
||||||
.ruma_route(server_server::get_server_version_route)
|
.ruma_route(server_server::get_server_version_route)
|
||||||
.route(
|
.route("/_matrix/key/v2/server", get(server_server::get_server_keys_route))
|
||||||
"/_matrix/key/v2/server",
|
|
||||||
get(server_server::get_server_keys_route),
|
|
||||||
)
|
|
||||||
.route(
|
.route(
|
||||||
"/_matrix/key/v2/server/:key_id",
|
"/_matrix/key/v2/server/:key_id",
|
||||||
get(server_server::get_server_keys_deprecated_route),
|
get(server_server::get_server_keys_deprecated_route),
|
||||||
|
@ -663,35 +667,18 @@ fn routes() -> Router {
|
||||||
.ruma_route(server_server::get_profile_information_route)
|
.ruma_route(server_server::get_profile_information_route)
|
||||||
.ruma_route(server_server::get_keys_route)
|
.ruma_route(server_server::get_keys_route)
|
||||||
.ruma_route(server_server::claim_keys_route)
|
.ruma_route(server_server::claim_keys_route)
|
||||||
.route(
|
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
|
||||||
"/_matrix/client/r0/rooms/:room_id/initialSync",
|
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
|
||||||
get(initial_sync),
|
.route("/client/server.json", get(client_server::syncv3_client_server_json))
|
||||||
)
|
.route("/.well-known/matrix/client", get(client_server::well_known_client_route))
|
||||||
.route(
|
.route("/.well-known/matrix/server", get(server_server::well_known_server_route))
|
||||||
"/_matrix/client/v3/rooms/:room_id/initialSync",
|
|
||||||
get(initial_sync),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/client/server.json",
|
|
||||||
get(client_server::syncv3_client_server_json),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/.well-known/matrix/client",
|
|
||||||
get(client_server::well_known_client_route),
|
|
||||||
)
|
|
||||||
.route(
|
|
||||||
"/.well-known/matrix/server",
|
|
||||||
get(server_server::well_known_server_route),
|
|
||||||
)
|
|
||||||
.route("/", get(it_works))
|
.route("/", get(it_works))
|
||||||
.fallback(not_found)
|
.fallback(not_found)
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
|
async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
|
||||||
let ctrl_c = async {
|
let ctrl_c = async {
|
||||||
signal::ctrl_c()
|
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
|
||||||
.await
|
|
||||||
.expect("failed to install Ctrl+C handler");
|
|
||||||
};
|
};
|
||||||
|
|
||||||
#[cfg(unix)]
|
#[cfg(unix)]
|
||||||
|
@ -721,19 +708,23 @@ async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
|
||||||
#[cfg(feature = "systemd")]
|
#[cfg(feature = "systemd")]
|
||||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
|
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);
|
||||||
|
|
||||||
tx.send(()).expect("failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your system may not be in an okay/ideal state.)");
|
tx.send(()).expect(
|
||||||
|
"failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \
|
||||||
|
system may not be in an okay/ideal state.)",
|
||||||
|
);
|
||||||
|
|
||||||
if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") {
|
if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") {
|
||||||
warn!("Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 seconds). Remaining connections: {}", handle.connection_count());
|
warn!(
|
||||||
|
"Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 \
|
||||||
|
seconds). Remaining connections: {}",
|
||||||
|
handle.connection_count()
|
||||||
|
);
|
||||||
|
|
||||||
#[cfg(feature = "systemd")]
|
#[cfg(feature = "systemd")]
|
||||||
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]);
|
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]);
|
||||||
}
|
}
|
||||||
|
|
||||||
warn!(
|
warn!("Time took to shutdown: {:?} seconds", shutdown_time_elapsed.elapsed());
|
||||||
"Time took to shutdown: {:?} seconds",
|
|
||||||
shutdown_time_elapsed.elapsed()
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
}
|
||||||
|
@ -744,15 +735,10 @@ async fn not_found(uri: Uri) -> impl IntoResponse {
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
async fn initial_sync(_uri: Uri) -> impl IntoResponse {
|
||||||
Error::BadRequest(
|
Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented")
|
||||||
ErrorKind::GuestAccessForbidden,
|
|
||||||
"Guest access not implemented",
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
async fn it_works() -> &'static str {
|
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }
|
||||||
"hewwo from conduwuit woof!"
|
|
||||||
}
|
|
||||||
|
|
||||||
trait RouterExt {
|
trait RouterExt {
|
||||||
fn ruma_route<H, T>(self, handler: H) -> Self
|
fn ruma_route<H, T>(self, handler: H) -> Self
|
||||||
|
@ -773,8 +759,8 @@ impl RouterExt for Router {
|
||||||
|
|
||||||
pub trait RumaHandler<T> {
|
pub trait RumaHandler<T> {
|
||||||
// Can't transform to a handler without boxing or relying on the nightly-only
|
// Can't transform to a handler without boxing or relying on the nightly-only
|
||||||
// impl-trait-in-traits feature. Moving a small amount of extra logic into the trait
|
// impl-trait-in-traits feature. Moving a small amount of extra logic into the
|
||||||
// allows bypassing both.
|
// trait allows bypassing both.
|
||||||
fn add_to_router(self, router: Router) -> Router;
|
fn add_to_router(self, router: Router) -> Router;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,35 +1,28 @@
|
||||||
use std::collections::HashMap;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
use crate::Result;
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
RoomId, UserId,
|
RoomId, UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
/// Places one event in the account data of the user and removes the previous entry.
|
/// Places one event in the account data of the user and removes the
|
||||||
|
/// previous entry.
|
||||||
fn update(
|
fn update(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
event_type: RoomAccountDataEventType,
|
|
||||||
data: &serde_json::Value,
|
data: &serde_json::Value,
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
/// Searches the account data for a specific kind.
|
/// Searches the account data for a specific kind.
|
||||||
fn get(
|
fn get(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
kind: RoomAccountDataEventType,
|
|
||||||
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
||||||
|
|
||||||
/// Returns all changes to the account data that happened after `since`.
|
/// Returns all changes to the account data that happened after `since`.
|
||||||
fn changes_since(
|
fn changes_since(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
since: u64,
|
|
||||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,15 +1,14 @@
|
||||||
mod data;
|
mod data;
|
||||||
|
|
||||||
pub(crate) use data::Data;
|
use std::collections::HashMap;
|
||||||
|
|
||||||
|
pub(crate) use data::Data;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
RoomId, UserId,
|
RoomId, UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::collections::HashMap;
|
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
|
@ -17,13 +16,11 @@ pub struct Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
/// Places one event in the account data of the user and removes the previous entry.
|
/// Places one event in the account data of the user and removes the
|
||||||
|
/// previous entry.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||||
pub fn update(
|
pub fn update(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
event_type: RoomAccountDataEventType,
|
|
||||||
data: &serde_json::Value,
|
data: &serde_json::Value,
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.db.update(room_id, user_id, event_type, data)
|
self.db.update(room_id, user_id, event_type, data)
|
||||||
|
@ -32,10 +29,7 @@ impl Service {
|
||||||
/// Searches the account data for a specific kind.
|
/// Searches the account data for a specific kind.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
||||||
pub fn get(
|
pub fn get(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
event_type: RoomAccountDataEventType,
|
|
||||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||||
self.db.get(room_id, user_id, event_type)
|
self.db.get(room_id, user_id, event_type)
|
||||||
}
|
}
|
||||||
|
@ -43,10 +37,7 @@ impl Service {
|
||||||
/// Returns all changes to the account data that happened after `since`.
|
/// Returns all changes to the account data that happened after `since`.
|
||||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||||
pub fn changes_since(
|
pub fn changes_since(
|
||||||
&self,
|
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||||
room_id: Option<&RoomId>,
|
|
||||||
user_id: &UserId,
|
|
||||||
since: u64,
|
|
||||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||||
self.db.changes_since(room_id, user_id, since)
|
self.db.changes_since(room_id, user_id, since)
|
||||||
}
|
}
|
||||||
|
|
File diff suppressed because it is too large
Load diff
|
@ -11,9 +11,7 @@ pub struct Service {
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
/// Registers an appservice and returns the ID to the caller
|
/// Registers an appservice and returns the ID to the caller
|
||||||
pub fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
pub fn register_appservice(&self, yaml: Registration) -> Result<String> { self.db.register_appservice(yaml) }
|
||||||
self.db.register_appservice(yaml)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Remove an appservice registration
|
/// Remove an appservice registration
|
||||||
///
|
///
|
||||||
|
@ -24,15 +22,9 @@ impl Service {
|
||||||
self.db.unregister_appservice(service_name)
|
self.db.unregister_appservice(service_name)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
|
||||||
self.db.get_registration(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
|
||||||
self.db.iter_ids()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> {
|
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
|
||||||
self.db.all()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -22,16 +22,12 @@ pub trait Data: Send + Sync {
|
||||||
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
||||||
fn remove_keypair(&self) -> Result<()>;
|
fn remove_keypair(&self) -> Result<()>;
|
||||||
fn add_signing_key(
|
fn add_signing_key(
|
||||||
&self,
|
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||||
origin: &ServerName,
|
|
||||||
new_keys: ServerSigningKeys,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||||
|
|
||||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||||
fn signing_keys_for(
|
/// for the server.
|
||||||
&self,
|
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||||
origin: &ServerName,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
|
||||||
fn database_version(&self) -> Result<u64>;
|
fn database_version(&self) -> Result<u64>;
|
||||||
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,6 +15,7 @@ use std::{
|
||||||
|
|
||||||
use argon2::Argon2;
|
use argon2::Argon2;
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
|
pub use data::Data;
|
||||||
use futures_util::FutureExt;
|
use futures_util::FutureExt;
|
||||||
use hyper::{
|
use hyper::{
|
||||||
client::connect::dns::{GaiResolver, Name},
|
client::connect::dns::{GaiResolver, Name},
|
||||||
|
@ -27,21 +28,16 @@ use ruma::{
|
||||||
client::sync::sync_events,
|
client::sync::sync_events,
|
||||||
federation::discovery::{ServerSigningKeys, VerifyKey},
|
federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||||
},
|
},
|
||||||
DeviceId, RoomVersionId, ServerName, UserId,
|
serde::Base64,
|
||||||
};
|
DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId,
|
||||||
use ruma::{
|
RoomVersionId, ServerName, UserId,
|
||||||
serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
|
|
||||||
OwnedServerSigningKeyId, OwnedUserId,
|
|
||||||
};
|
};
|
||||||
use sha2::Digest;
|
use sha2::Digest;
|
||||||
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
|
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
|
||||||
use tracing::{error, info};
|
use tracing::{error, info};
|
||||||
use trust_dns_resolver::TokioAsyncResolver;
|
use trust_dns_resolver::TokioAsyncResolver;
|
||||||
|
|
||||||
pub use data::Data;
|
use crate::{api::server_server::FedDest, services, Config, Error, Result};
|
||||||
|
|
||||||
use crate::api::server_server::FedDest;
|
|
||||||
use crate::{services, Config, Error, Result};
|
|
||||||
|
|
||||||
mod data;
|
mod data;
|
||||||
|
|
||||||
|
@ -83,9 +79,11 @@ pub struct Service<'a> {
|
||||||
pub argon: Argon2<'a>,
|
pub argon: Argon2<'a>,
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
|
/// Handles "rotation" of long-polling requests. "Rotation" in this context is
|
||||||
|
/// similar to "rotation" of log files and the like.
|
||||||
///
|
///
|
||||||
/// This is utilized to have sync workers return early and release read locks on the database.
|
/// This is utilized to have sync workers return early and release read locks on
|
||||||
|
/// the database.
|
||||||
pub(crate) struct RotationHandler(broadcast::Sender<()>, ());
|
pub(crate) struct RotationHandler(broadcast::Sender<()>, ());
|
||||||
|
|
||||||
impl RotationHandler {
|
impl RotationHandler {
|
||||||
|
@ -102,15 +100,11 @@ impl RotationHandler {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn fire(&self) {
|
pub fn fire(&self) { let _ = self.0.send(()); }
|
||||||
let _ = self.0.send(());
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Default for RotationHandler {
|
impl Default for RotationHandler {
|
||||||
fn default() -> Self {
|
fn default() -> Self { Self::new() }
|
||||||
Self::new()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
struct Resolver {
|
struct Resolver {
|
||||||
|
@ -162,15 +156,13 @@ impl Service<'_> {
|
||||||
error!("Keypair invalid. Deleting...");
|
error!("Keypair invalid. Deleting...");
|
||||||
db.remove_keypair()?;
|
db.remove_keypair()?;
|
||||||
return Err(e);
|
return Err(e);
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
|
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
|
||||||
|
|
||||||
let jwt_decoding_key = config
|
let jwt_decoding_key =
|
||||||
.jwt_secret
|
config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
||||||
.as_ref()
|
|
||||||
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
|
||||||
|
|
||||||
let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?;
|
let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?;
|
||||||
let default_client = reqwest_client_builder(&config)?.build()?;
|
let default_client = reqwest_client_builder(&config)?.build()?;
|
||||||
|
@ -205,10 +197,7 @@ impl Service<'_> {
|
||||||
config,
|
config,
|
||||||
keypair: Arc::new(keypair),
|
keypair: Arc::new(keypair),
|
||||||
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
|
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
|
||||||
error!(
|
error!("Failed to set up trust dns resolver with system config: {}", e);
|
||||||
"Failed to set up trust dns resolver with system config: {}",
|
|
||||||
e
|
|
||||||
);
|
|
||||||
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
||||||
})?,
|
})?,
|
||||||
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())),
|
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())),
|
||||||
|
@ -236,10 +225,7 @@ impl Service<'_> {
|
||||||
|
|
||||||
fs::create_dir_all(s.get_media_folder())?;
|
fs::create_dir_all(s.get_media_folder())?;
|
||||||
|
|
||||||
if !s
|
if !s.supported_room_versions().contains(&s.config.default_room_version) {
|
||||||
.supported_room_versions()
|
|
||||||
.contains(&s.config.default_room_version)
|
|
||||||
{
|
|
||||||
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
|
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
|
||||||
s.config.default_room_version = crate::config::default_default_room_version();
|
s.config.default_room_version = crate::config::default_default_room_version();
|
||||||
};
|
};
|
||||||
|
@ -248,12 +234,11 @@ impl Service<'_> {
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns this server's keypair.
|
/// Returns this server's keypair.
|
||||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair {
|
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
|
||||||
&self.keypair
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns a reqwest client which can be used to send requests for URL previews
|
/// Returns a reqwest client which can be used to send requests for URL
|
||||||
/// This is the same as `default_client()` except a redirect policy of max 2 is set
|
/// previews This is the same as `default_client()` except a redirect policy
|
||||||
|
/// of max 2 is set
|
||||||
pub fn url_preview_client(&self) -> reqwest::Client {
|
pub fn url_preview_client(&self) -> reqwest::Client {
|
||||||
// Client is cheap to clone (Arc wrapper) and avoids lifetime issues
|
// Client is cheap to clone (Arc wrapper) and avoids lifetime issues
|
||||||
self.url_preview_client.clone()
|
self.url_preview_client.clone()
|
||||||
|
@ -272,60 +257,36 @@ impl Service<'_> {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn next_count(&self) -> Result<u64> {
|
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
||||||
self.db.next_count()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn current_count(&self) -> Result<u64> {
|
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
|
||||||
self.db.current_count()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn last_check_for_updates_id(&self) -> Result<u64> {
|
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
|
||||||
self.db.last_check_for_updates_id()
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) }
|
||||||
self.db.update_check_for_updates_id(id)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||||
self.db.watch(user_id, device_id).await
|
self.db.watch(user_id, device_id).await
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn cleanup(&self) -> Result<()> {
|
pub fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||||
self.db.cleanup()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn server_name(&self) -> &ServerName {
|
pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
|
||||||
self.config.server_name.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max_request_size(&self) -> u32 {
|
pub fn max_request_size(&self) -> u32 { self.config.max_request_size }
|
||||||
self.config.max_request_size
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn max_fetch_prev_events(&self) -> u16 {
|
pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events }
|
||||||
self.config.max_fetch_prev_events
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_registration(&self) -> bool {
|
pub fn allow_registration(&self) -> bool { self.config.allow_registration }
|
||||||
self.config.allow_registration
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_guest_registration(&self) -> bool {
|
pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
|
||||||
self.config.allow_guest_registration
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_encryption(&self) -> bool {
|
pub fn allow_encryption(&self) -> bool { self.config.allow_encryption }
|
||||||
self.config.allow_encryption
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_federation(&self) -> bool {
|
pub fn allow_federation(&self) -> bool { self.config.allow_federation }
|
||||||
self.config.allow_federation
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_public_room_directory_over_federation(&self) -> bool {
|
pub fn allow_public_room_directory_over_federation(&self) -> bool {
|
||||||
self.config.allow_public_room_directory_over_federation
|
self.config.allow_public_room_directory_over_federation
|
||||||
|
@ -335,73 +296,39 @@ impl Service<'_> {
|
||||||
self.config.allow_public_room_directory_without_auth
|
self.config.allow_public_room_directory_without_auth
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn allow_device_name_federation(&self) -> bool {
|
pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
|
||||||
self.config.allow_device_name_federation
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_room_creation(&self) -> bool {
|
pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
|
||||||
self.config.allow_room_creation
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_unstable_room_versions(&self) -> bool {
|
pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions }
|
||||||
self.config.allow_unstable_room_versions
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn default_room_version(&self) -> RoomVersionId {
|
pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() }
|
||||||
self.config.default_room_version.clone()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn new_user_displayname_suffix(&self) -> &String {
|
pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
|
||||||
&self.config.new_user_displayname_suffix
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_check_for_updates(&self) -> bool {
|
pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
|
||||||
self.config.allow_check_for_updates
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn trusted_servers(&self) -> &[OwnedServerName] {
|
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
|
||||||
&self.config.trusted_servers
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn query_trusted_key_servers_first(&self) -> bool {
|
pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first }
|
||||||
self.config.query_trusted_key_servers_first
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
|
pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.dns_resolver }
|
||||||
&self.dns_resolver
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> {
|
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
|
||||||
self.jwt_decoding_key.as_ref()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn turn_password(&self) -> &String {
|
pub fn turn_password(&self) -> &String { &self.config.turn_password }
|
||||||
&self.config.turn_password
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn turn_ttl(&self) -> u64 {
|
pub fn turn_ttl(&self) -> u64 { self.config.turn_ttl }
|
||||||
self.config.turn_ttl
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn turn_uris(&self) -> &[String] {
|
pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris }
|
||||||
&self.config.turn_uris
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn turn_username(&self) -> &String {
|
pub fn turn_username(&self) -> &String { &self.config.turn_username }
|
||||||
&self.config.turn_username
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn turn_secret(&self) -> &String {
|
pub fn turn_secret(&self) -> &String { &self.config.turn_secret }
|
||||||
&self.config.turn_secret
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn notification_push_path(&self) -> &String {
|
pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path }
|
||||||
&self.config.notification_push_path
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn emergency_password(&self) -> &Option<String> {
|
pub fn emergency_password(&self) -> &Option<String> { &self.config.emergency_password }
|
||||||
&self.config.emergency_password
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
|
pub fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
|
||||||
&self.config.url_preview_domain_contains_allowlist
|
&self.config.url_preview_domain_contains_allowlist
|
||||||
|
@ -411,77 +338,41 @@ impl Service<'_> {
|
||||||
&self.config.url_preview_domain_explicit_allowlist
|
&self.config.url_preview_domain_explicit_allowlist
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
|
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> { &self.config.url_preview_url_contains_allowlist }
|
||||||
&self.config.url_preview_url_contains_allowlist
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn url_preview_max_spider_size(&self) -> usize {
|
pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
|
||||||
self.config.url_preview_max_spider_size
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn url_preview_check_root_domain(&self) -> bool {
|
pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
|
||||||
self.config.url_preview_check_root_domain
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forbidden_room_names(&self) -> &RegexSet {
|
pub fn forbidden_room_names(&self) -> &RegexSet { &self.config.forbidden_room_names }
|
||||||
&self.config.forbidden_room_names
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn forbidden_usernames(&self) -> &RegexSet {
|
pub fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames }
|
||||||
&self.config.forbidden_usernames
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_local_presence(&self) -> bool {
|
pub fn allow_local_presence(&self) -> bool { self.config.allow_local_presence }
|
||||||
self.config.allow_local_presence
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_incoming_presence(&self) -> bool {
|
pub fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence }
|
||||||
self.config.allow_incoming_presence
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn allow_outgoing_presence(&self) -> bool {
|
pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
|
||||||
self.config.allow_outgoing_presence
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn presence_idle_timeout_s(&self) -> u64 {
|
pub fn presence_idle_timeout_s(&self) -> u64 { self.config.presence_idle_timeout_s }
|
||||||
self.config.presence_idle_timeout_s
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn presence_offline_timeout_s(&self) -> u64 {
|
pub fn presence_offline_timeout_s(&self) -> u64 { self.config.presence_offline_timeout_s }
|
||||||
self.config.presence_offline_timeout_s
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rocksdb_log_level(&self) -> &String {
|
pub fn rocksdb_log_level(&self) -> &String { &self.config.rocksdb_log_level }
|
||||||
&self.config.rocksdb_log_level
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rocksdb_max_log_file_size(&self) -> usize {
|
pub fn rocksdb_max_log_file_size(&self) -> usize { self.config.rocksdb_max_log_file_size }
|
||||||
self.config.rocksdb_max_log_file_size
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rocksdb_log_time_to_roll(&self) -> usize {
|
pub fn rocksdb_log_time_to_roll(&self) -> usize { self.config.rocksdb_log_time_to_roll }
|
||||||
self.config.rocksdb_log_time_to_roll
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool {
|
pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool { self.config.rocksdb_optimize_for_spinning_disks }
|
||||||
self.config.rocksdb_optimize_for_spinning_disks
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn rocksdb_parallelism_threads(&self) -> usize {
|
pub fn rocksdb_parallelism_threads(&self) -> usize { self.config.rocksdb_parallelism_threads }
|
||||||
self.config.rocksdb_parallelism_threads
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] {
|
pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from }
|
||||||
&self.config.prevent_media_downloads_from
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn ip_range_denylist(&self) -> &[String] {
|
pub fn ip_range_denylist(&self) -> &[String] { &self.config.ip_range_denylist }
|
||||||
&self.config.ip_range_denylist
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn block_non_admin_invites(&self) -> bool {
|
pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
|
||||||
self.config.block_non_admin_invites
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn supported_room_versions(&self) -> Vec<RoomVersionId> {
|
pub fn supported_room_versions(&self) -> Vec<RoomVersionId> {
|
||||||
let mut room_versions: Vec<RoomVersionId> = vec![];
|
let mut room_versions: Vec<RoomVersionId> = vec![];
|
||||||
|
@ -492,24 +383,22 @@ impl Service<'_> {
|
||||||
room_versions
|
room_versions
|
||||||
}
|
}
|
||||||
|
|
||||||
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored in room version > 4
|
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored
|
||||||
|
/// in room version > 4
|
||||||
///
|
///
|
||||||
/// Remove the outdated keys and insert the new ones.
|
/// Remove the outdated keys and insert the new ones.
|
||||||
///
|
///
|
||||||
/// This doesn't actually check that the keys provided are newer than the old set.
|
/// This doesn't actually check that the keys provided are newer than the
|
||||||
|
/// old set.
|
||||||
pub fn add_signing_key(
|
pub fn add_signing_key(
|
||||||
&self,
|
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||||
origin: &ServerName,
|
|
||||||
new_keys: ServerSigningKeys,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||||
self.db.add_signing_key(origin, new_keys)
|
self.db.add_signing_key(origin, new_keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||||
pub fn signing_keys_for(
|
/// for the server.
|
||||||
&self,
|
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||||
origin: &ServerName,
|
|
||||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
|
||||||
let mut keys = self.db.signing_keys_for(origin)?;
|
let mut keys = self.db.signing_keys_for(origin)?;
|
||||||
if origin == self.server_name() {
|
if origin == self.server_name() {
|
||||||
keys.insert(
|
keys.insert(
|
||||||
|
@ -525,13 +414,9 @@ impl Service<'_> {
|
||||||
Ok(keys)
|
Ok(keys)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn database_version(&self) -> Result<u64> {
|
pub fn database_version(&self) -> Result<u64> { self.db.database_version() }
|
||||||
self.db.database_version()
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) }
|
||||||
self.db.bump_database_version(new_version)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_media_folder(&self) -> PathBuf {
|
pub fn get_media_folder(&self) -> PathBuf {
|
||||||
let mut r = PathBuf::new();
|
let mut r = PathBuf::new();
|
||||||
|
@ -540,20 +425,23 @@ impl Service<'_> {
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
/// new SHA256 file name media function, requires "sha256_media" feature flag enabled and database migrated
|
/// new SHA256 file name media function, requires "sha256_media" feature
|
||||||
/// uses SHA256 hash of the base64 key as the file name
|
/// flag enabled and database migrated uses SHA256 hash of the base64 key as
|
||||||
|
/// the file name
|
||||||
pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
|
pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
|
||||||
let mut r = PathBuf::new();
|
let mut r = PathBuf::new();
|
||||||
r.push(self.config.database_path.clone());
|
r.push(self.config.database_path.clone());
|
||||||
r.push("media");
|
r.push("media");
|
||||||
// Using the hash of the base64 key as the filename
|
// Using the hash of the base64 key as the filename
|
||||||
// This is to prevent the total length of the path from exceeding the maximum length in most filesystems
|
// This is to prevent the total length of the path from exceeding the maximum
|
||||||
|
// length in most filesystems
|
||||||
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
|
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
/// old base64 file name media function
|
/// old base64 file name media function
|
||||||
/// This is the old version of `get_media_file` that uses the full base64 key as the filename.
|
/// This is the old version of `get_media_file` that uses the full base64
|
||||||
|
/// key as the filename.
|
||||||
///
|
///
|
||||||
/// This is deprecated and will be removed in a future release.
|
/// This is deprecated and will be removed in a future release.
|
||||||
/// Please use `get_media_file_new` instead.
|
/// Please use `get_media_file_new` instead.
|
||||||
|
@ -566,17 +454,11 @@ impl Service<'_> {
|
||||||
r
|
r
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn well_known_client(&self) -> &Option<String> {
|
pub fn well_known_client(&self) -> &Option<String> { &self.config.well_known_client }
|
||||||
&self.config.well_known_client
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn well_known_server(&self) -> &Option<String> {
|
pub fn well_known_server(&self) -> &Option<String> { &self.config.well_known_server }
|
||||||
&self.config.well_known_server
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn unix_socket_path(&self) -> &Option<PathBuf> {
|
pub fn unix_socket_path(&self) -> &Option<PathBuf> { &self.config.unix_socket_path }
|
||||||
&self.config.unix_socket_path
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn shutdown(&self) {
|
pub fn shutdown(&self) {
|
||||||
self.shutdown.store(true, atomic::Ordering::Relaxed);
|
self.shutdown.store(true, atomic::Ordering::Relaxed);
|
||||||
|
@ -586,7 +468,7 @@ impl Service<'_> {
|
||||||
match &self.unix_socket_path() {
|
match &self.unix_socket_path() {
|
||||||
Some(path) => {
|
Some(path) => {
|
||||||
std::fs::remove_file(path).unwrap();
|
std::fs::remove_file(path).unwrap();
|
||||||
}
|
},
|
||||||
None => error!(
|
None => error!(
|
||||||
"Unable to remove socket file at {:?} during shutdown.",
|
"Unable to remove socket file at {:?} during shutdown.",
|
||||||
&self.unix_socket_path()
|
&self.unix_socket_path()
|
||||||
|
@ -613,11 +495,7 @@ fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
||||||
.connect_timeout(Duration::from_secs(60))
|
.connect_timeout(Duration::from_secs(60))
|
||||||
.timeout(Duration::from_secs(60 * 5))
|
.timeout(Duration::from_secs(60 * 5))
|
||||||
.redirect(redirect_policy)
|
.redirect(redirect_policy)
|
||||||
.user_agent(concat!(
|
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")));
|
||||||
env!("CARGO_PKG_NAME"),
|
|
||||||
"/",
|
|
||||||
env!("CARGO_PKG_VERSION")
|
|
||||||
));
|
|
||||||
|
|
||||||
if let Some(proxy) = config.proxy.to_proxy()? {
|
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||||
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||||
|
@ -627,8 +505,10 @@ fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
||||||
}
|
}
|
||||||
|
|
||||||
fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
||||||
// for security reasons (e.g. malicious open redirect), we do not want to follow too many redirects when generating URL previews.
|
// for security reasons (e.g. malicious open redirect), we do not want to follow
|
||||||
// let's keep it at least 2 to account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider raising it to 3.
|
// too many redirects when generating URL previews. let's keep it at least 2 to
|
||||||
|
// account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider
|
||||||
|
// raising it to 3.
|
||||||
let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
|
let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
|
||||||
if attempt.previous().len() > 2 {
|
if attempt.previous().len() > 2 {
|
||||||
attempt.error("Too many redirects (max is 2)")
|
attempt.error("Too many redirects (max is 2)")
|
||||||
|
@ -642,11 +522,7 @@ fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::Client
|
||||||
.connect_timeout(Duration::from_secs(60))
|
.connect_timeout(Duration::from_secs(60))
|
||||||
.timeout(Duration::from_secs(60 * 5))
|
.timeout(Duration::from_secs(60 * 5))
|
||||||
.redirect(redirect_policy)
|
.redirect(redirect_policy)
|
||||||
.user_agent(concat!(
|
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")));
|
||||||
env!("CARGO_PKG_NAME"),
|
|
||||||
"/",
|
|
||||||
env!("CARGO_PKG_VERSION")
|
|
||||||
));
|
|
||||||
|
|
||||||
if let Some(proxy) = config.proxy.to_proxy()? {
|
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||||
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||||
|
|
|
@ -1,78 +1,47 @@
|
||||||
use std::collections::BTreeMap;
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use crate::Result;
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
OwnedRoomId, RoomId, UserId,
|
OwnedRoomId, RoomId, UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
fn create_backup(
|
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String>;
|
|
||||||
|
|
||||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||||
|
|
||||||
fn update_backup(
|
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String>;
|
|
||||||
|
|
||||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
||||||
|
|
||||||
fn get_latest_backup(&self, user_id: &UserId)
|
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||||
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
|
||||||
|
|
||||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
||||||
|
|
||||||
fn add_key(
|
fn add_key(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
key_data: &Raw<KeyBackupData>,
|
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
||||||
|
|
||||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
||||||
|
|
||||||
fn get_all(
|
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
|
||||||
|
|
||||||
fn get_room(
|
fn get_room(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
||||||
|
|
||||||
fn get_session(
|
fn get_session(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<Option<Raw<KeyBackupData>>>;
|
) -> Result<Option<Raw<KeyBackupData>>>;
|
||||||
|
|
||||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||||
|
|
||||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
||||||
|
|
||||||
fn delete_room_key(
|
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<()>;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,24 +1,21 @@
|
||||||
mod data;
|
mod data;
|
||||||
pub(crate) use data::Data;
|
use std::collections::BTreeMap;
|
||||||
|
|
||||||
use crate::Result;
|
pub(crate) use data::Data;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
OwnedRoomId, RoomId, UserId,
|
OwnedRoomId, RoomId, UserId,
|
||||||
};
|
};
|
||||||
use std::collections::BTreeMap;
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
pub db: &'static dyn Data,
|
pub db: &'static dyn Data,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
pub fn create_backup(
|
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String> {
|
|
||||||
self.db.create_backup(user_id, backup_metadata)
|
self.db.create_backup(user_id, backup_metadata)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -27,10 +24,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn update_backup(
|
pub fn update_backup(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
backup_metadata: &Raw<BackupAlgorithm>,
|
|
||||||
) -> Result<String> {
|
) -> Result<String> {
|
||||||
self.db.update_backup(user_id, version, backup_metadata)
|
self.db.update_backup(user_id, version, backup_metadata)
|
||||||
}
|
}
|
||||||
|
@ -39,64 +33,36 @@ impl Service {
|
||||||
self.db.get_latest_backup_version(user_id)
|
self.db.get_latest_backup_version(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_latest_backup(
|
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
|
||||||
self.db.get_latest_backup(user_id)
|
self.db.get_latest_backup(user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_backup(
|
pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
|
||||||
self.db.get_backup(user_id, version)
|
self.db.get_backup(user_id, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn add_key(
|
pub fn add_key(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
key_data: &Raw<KeyBackupData>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.db
|
self.db.add_key(user_id, version, room_id, session_id, key_data)
|
||||||
.add_key(user_id, version, room_id, session_id, key_data)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
|
||||||
self.db.count_keys(user_id, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
|
||||||
self.db.get_etag(user_id, version)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_all(
|
pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
|
||||||
self.db.get_all(user_id, version)
|
self.db.get_all(user_id, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_room(
|
pub fn get_room(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||||
self.db.get_room(user_id, version, room_id)
|
self.db.get_room(user_id, version, room_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_session(
|
pub fn get_session(
|
||||||
&self,
|
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||||
self.db.get_session(user_id, version, room_id, session_id)
|
self.db.get_session(user_id, version, room_id, session_id)
|
||||||
}
|
}
|
||||||
|
@ -105,23 +71,11 @@ impl Service {
|
||||||
self.db.delete_all_keys(user_id, version)
|
self.db.delete_all_keys(user_id, version)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete_room_keys(
|
pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.db.delete_room_keys(user_id, version, room_id)
|
self.db.delete_room_keys(user_id, version, room_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn delete_room_key(
|
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||||
&self,
|
self.db.delete_room_key(user_id, version, room_id, session_id)
|
||||||
user_id: &UserId,
|
|
||||||
version: &str,
|
|
||||||
room_id: &RoomId,
|
|
||||||
session_id: &str,
|
|
||||||
) -> Result<()> {
|
|
||||||
self.db
|
|
||||||
.delete_room_key(user_id, version, room_id, session_id)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,22 +2,14 @@ use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
fn create_file_metadata(
|
fn create_file_metadata(
|
||||||
&self,
|
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||||
mxc: String,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
content_disposition: Option<&str>,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
) -> Result<Vec<u8>>;
|
) -> Result<Vec<u8>>;
|
||||||
|
|
||||||
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
|
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
|
||||||
|
|
||||||
/// Returns content_disposition, content_type and the metadata key.
|
/// Returns content_disposition, content_type and the metadata key.
|
||||||
fn search_file_metadata(
|
fn search_file_metadata(
|
||||||
&self,
|
&self, mxc: String, width: u32, height: u32,
|
||||||
mxc: String,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
||||||
|
|
||||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
|
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
|
||||||
|
@ -26,12 +18,7 @@ pub trait Data: Send + Sync {
|
||||||
|
|
||||||
fn remove_url_preview(&self, url: &str) -> Result<()>;
|
fn remove_url_preview(&self, url: &str) -> Result<()>;
|
||||||
|
|
||||||
fn set_url_preview(
|
fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
|
||||||
&self,
|
|
||||||
url: &str,
|
|
||||||
data: &super::UrlPreviewData,
|
|
||||||
timestamp: std::time::Duration,
|
|
||||||
) -> Result<()>;
|
|
||||||
|
|
||||||
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
|
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -7,18 +7,17 @@ use std::{
|
||||||
};
|
};
|
||||||
|
|
||||||
pub(crate) use data::Data;
|
pub(crate) use data::Data;
|
||||||
|
use image::imageops::FilterType;
|
||||||
use ruma::OwnedMxcUri;
|
use ruma::OwnedMxcUri;
|
||||||
use serde::Serialize;
|
use serde::Serialize;
|
||||||
use tracing::{debug, error};
|
|
||||||
|
|
||||||
use crate::{services, utils, Error, Result};
|
|
||||||
use image::imageops::FilterType;
|
|
||||||
|
|
||||||
use tokio::{
|
use tokio::{
|
||||||
fs::{self, File},
|
fs::{self, File},
|
||||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||||
sync::Mutex,
|
sync::Mutex,
|
||||||
};
|
};
|
||||||
|
use tracing::{debug, error};
|
||||||
|
|
||||||
|
use crate::{services, utils, Error, Result};
|
||||||
|
|
||||||
#[derive(Debug)]
|
#[derive(Debug)]
|
||||||
pub struct FileMeta {
|
pub struct FileMeta {
|
||||||
|
@ -29,35 +28,17 @@ pub struct FileMeta {
|
||||||
|
|
||||||
#[derive(Serialize, Default)]
|
#[derive(Serialize, Default)]
|
||||||
pub struct UrlPreviewData {
|
pub struct UrlPreviewData {
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "og:title")
|
|
||||||
)]
|
|
||||||
pub title: Option<String>,
|
pub title: Option<String>,
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "og:description")
|
|
||||||
)]
|
|
||||||
pub description: Option<String>,
|
pub description: Option<String>,
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "og:image")
|
|
||||||
)]
|
|
||||||
pub image: Option<String>,
|
pub image: Option<String>,
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "matrix:image:size")
|
|
||||||
)]
|
|
||||||
pub image_size: Option<usize>,
|
pub image_size: Option<usize>,
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "og:image:width")
|
|
||||||
)]
|
|
||||||
pub image_width: Option<u32>,
|
pub image_width: Option<u32>,
|
||||||
#[serde(
|
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
|
||||||
skip_serializing_if = "Option::is_none",
|
|
||||||
rename(serialize = "og:image:height")
|
|
||||||
)]
|
|
||||||
pub image_height: Option<u32>,
|
pub image_height: Option<u32>,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -69,16 +50,10 @@ pub struct Service {
|
||||||
impl Service {
|
impl Service {
|
||||||
/// Uploads a file.
|
/// Uploads a file.
|
||||||
pub async fn create(
|
pub async fn create(
|
||||||
&self,
|
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8],
|
||||||
mxc: String,
|
|
||||||
content_disposition: Option<&str>,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
file: &[u8],
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
// Width, Height = 0 if it's not a thumbnail
|
// Width, Height = 0 if it's not a thumbnail
|
||||||
let key = self
|
let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||||
.db
|
|
||||||
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
|
||||||
|
|
||||||
let path = if cfg!(feature = "sha256_media") {
|
let path = if cfg!(feature = "sha256_media") {
|
||||||
services().globals.get_media_file_new(&key)
|
services().globals.get_media_file_new(&key)
|
||||||
|
@ -104,10 +79,7 @@ impl Service {
|
||||||
};
|
};
|
||||||
debug!("Got local file path: {:?}", file_path);
|
debug!("Got local file path: {:?}", file_path);
|
||||||
|
|
||||||
debug!(
|
debug!("Deleting local file {:?} from filesystem, original MXC: {}", file_path, mxc);
|
||||||
"Deleting local file {:?} from filesystem, original MXC: {}",
|
|
||||||
file_path, mxc
|
|
||||||
);
|
|
||||||
tokio::fs::remove_file(file_path).await?;
|
tokio::fs::remove_file(file_path).await?;
|
||||||
|
|
||||||
debug!("Deleting MXC {mxc} from database");
|
debug!("Deleting MXC {mxc} from database");
|
||||||
|
@ -117,23 +89,18 @@ impl Service {
|
||||||
Ok(())
|
Ok(())
|
||||||
} else {
|
} else {
|
||||||
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
|
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
|
||||||
Err(Error::bad_database("Failed to find any media keys for the provided MXC in our database (MXC does not exist)"))
|
Err(Error::bad_database(
|
||||||
|
"Failed to find any media keys for the provided MXC in our database (MXC does not exist)",
|
||||||
|
))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Uploads or replaces a file thumbnail.
|
/// Uploads or replaces a file thumbnail.
|
||||||
pub async fn upload_thumbnail(
|
pub async fn upload_thumbnail(
|
||||||
&self,
|
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32,
|
||||||
mxc: String,
|
|
||||||
content_disposition: Option<&str>,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
file: &[u8],
|
file: &[u8],
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let key =
|
let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||||
self.db
|
|
||||||
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
|
||||||
|
|
||||||
let path = if cfg!(feature = "sha256_media") {
|
let path = if cfg!(feature = "sha256_media") {
|
||||||
services().globals.get_media_file_new(&key)
|
services().globals.get_media_file_new(&key)
|
||||||
|
@ -150,9 +117,7 @@ impl Service {
|
||||||
|
|
||||||
/// Downloads a file.
|
/// Downloads a file.
|
||||||
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
||||||
if let Ok((content_disposition, content_type, key)) =
|
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) {
|
||||||
self.db.search_file_metadata(mxc, 0, 0)
|
|
||||||
{
|
|
||||||
let path = if cfg!(feature = "sha256_media") {
|
let path = if cfg!(feature = "sha256_media") {
|
||||||
services().globals.get_media_file_new(&key)
|
services().globals.get_media_file_new(&key)
|
||||||
} else {
|
} else {
|
||||||
|
@ -161,9 +126,7 @@ impl Service {
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut file = Vec::new();
|
let mut file = Vec::new();
|
||||||
BufReader::new(File::open(path).await?)
|
BufReader::new(File::open(path).await?).read_to_end(&mut file).await?;
|
||||||
.read_to_end(&mut file)
|
|
||||||
.await?;
|
|
||||||
|
|
||||||
Ok(Some(FileMeta {
|
Ok(Some(FileMeta {
|
||||||
content_disposition,
|
content_disposition,
|
||||||
|
@ -175,8 +138,8 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Deletes all remote only media files in the given at or after time/duration. Returns a u32
|
/// Deletes all remote only media files in the given at or after
|
||||||
/// with the amount of media files deleted.
|
/// time/duration. Returns a u32 with the amount of media files deleted.
|
||||||
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
|
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
|
||||||
if let Ok(all_keys) = self.db.get_all_media_keys() {
|
if let Ok(all_keys) = self.db.get_all_media_keys() {
|
||||||
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
|
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
|
||||||
|
@ -184,13 +147,11 @@ impl Service {
|
||||||
debug!("Parsed duration: {:?}", duration);
|
debug!("Parsed duration: {:?}", duration);
|
||||||
debug!("System time now: {:?}", SystemTime::now());
|
debug!("System time now: {:?}", SystemTime::now());
|
||||||
SystemTime::now() - duration
|
SystemTime::now() - duration
|
||||||
}
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
error!("Failed to parse user-specified time duration: {}", e);
|
error!("Failed to parse user-specified time duration: {}", e);
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database("Failed to parse user-specified time duration."));
|
||||||
"Failed to parse user-specified time duration.",
|
},
|
||||||
));
|
|
||||||
}
|
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut remote_mxcs: Vec<String> = vec![];
|
let mut remote_mxcs: Vec<String> = vec![];
|
||||||
|
@ -198,17 +159,16 @@ impl Service {
|
||||||
for key in all_keys {
|
for key in all_keys {
|
||||||
debug!("Full MXC key from database: {:?}", key);
|
debug!("Full MXC key from database: {:?}", key);
|
||||||
|
|
||||||
// we need to get the MXC URL from the first part of the key (the first 0xff / 255 push)
|
// we need to get the MXC URL from the first part of the key (the first 0xff /
|
||||||
// this code does look kinda crazy but blame conduit for using magic keys
|
// 255 push) this code does look kinda crazy but blame conduit for using magic
|
||||||
let mut parts = key.split(|&b| b == 0xff);
|
// keys
|
||||||
|
let mut parts = key.split(|&b| b == 0xFF);
|
||||||
let mxc = parts
|
let mxc = parts
|
||||||
.next()
|
.next()
|
||||||
.map(|bytes| {
|
.map(|bytes| {
|
||||||
utils::string_from_bytes(bytes).map_err(|e| {
|
utils::string_from_bytes(bytes).map_err(|e| {
|
||||||
error!("Failed to parse MXC unicode bytes from our database: {}", e);
|
error!("Failed to parse MXC unicode bytes from our database: {}", e);
|
||||||
Error::bad_database(
|
Error::bad_database("Failed to parse MXC unicode bytes from our database")
|
||||||
"Failed to parse MXC unicode bytes from our database",
|
|
||||||
)
|
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
.transpose()?;
|
.transpose()?;
|
||||||
|
@ -219,7 +179,7 @@ impl Service {
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database(
|
||||||
"Parsed MXC URL unicode bytes from database but still is None",
|
"Parsed MXC URL unicode bytes from database but still is None",
|
||||||
));
|
));
|
||||||
}
|
},
|
||||||
};
|
};
|
||||||
|
|
||||||
debug!("Parsed MXC key to URL: {}", mxc_s);
|
debug!("Parsed MXC key to URL: {}", mxc_s);
|
||||||
|
@ -252,12 +212,13 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("Finished going through all our media in database for eligible keys to delete, checking if these are empty");
|
debug!(
|
||||||
|
"Finished going through all our media in database for eligible keys to delete, checking if these are \
|
||||||
|
empty"
|
||||||
|
);
|
||||||
|
|
||||||
if remote_mxcs.is_empty() {
|
if remote_mxcs.is_empty() {
|
||||||
return Err(Error::bad_database(
|
return Err(Error::bad_database("Did not found any eligible MXCs to delete."));
|
||||||
"Did not found any eligible MXCs to delete.",
|
|
||||||
));
|
|
||||||
}
|
}
|
||||||
|
|
||||||
debug!("Deleting media now in the past \"{:?}\".", user_duration);
|
debug!("Deleting media now in the past \"{:?}\".", user_duration);
|
||||||
|
@ -278,8 +239,8 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
/// Returns width, height of the thumbnail and whether it should be cropped.
|
||||||
/// the server should send the original file.
|
/// Returns None when the server should send the original file.
|
||||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||||
match (width, height) {
|
match (width, height) {
|
||||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||||
|
@ -296,24 +257,18 @@ impl Service {
|
||||||
/// Here's an example on how it works:
|
/// Here's an example on how it works:
|
||||||
///
|
///
|
||||||
/// - Client requests an image with width=567, height=567
|
/// - Client requests an image with width=567, height=567
|
||||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
/// - Server rounds that up to (800, 600), so it doesn't have to save too
|
||||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
/// many thumbnails
|
||||||
|
/// - Server rounds that up again to (958, 600) to fix the aspect ratio
|
||||||
|
/// (only for width,height>96)
|
||||||
/// - Server creates the thumbnail and sends it to the user
|
/// - Server creates the thumbnail and sends it to the user
|
||||||
///
|
///
|
||||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
/// For width,height <= 96 the server uses another thumbnailing algorithm
|
||||||
pub async fn get_thumbnail(
|
/// which crops the image afterwards.
|
||||||
&self,
|
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
|
||||||
mxc: String,
|
let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
) -> Result<Option<FileMeta>> {
|
|
||||||
let (width, height, crop) = self
|
|
||||||
.thumbnail_properties(width, height)
|
|
||||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
|
||||||
|
|
||||||
if let Ok((content_disposition, content_type, key)) =
|
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) {
|
||||||
self.db.search_file_metadata(mxc.clone(), width, height)
|
|
||||||
{
|
|
||||||
// Using saved thumbnail
|
// Using saved thumbnail
|
||||||
let path = if cfg!(feature = "sha256_media") {
|
let path = if cfg!(feature = "sha256_media") {
|
||||||
services().globals.get_media_file_new(&key)
|
services().globals.get_media_file_new(&key)
|
||||||
|
@ -330,9 +285,7 @@ impl Service {
|
||||||
content_type,
|
content_type,
|
||||||
file: file.clone(),
|
file: file.clone(),
|
||||||
}))
|
}))
|
||||||
} else if let Ok((content_disposition, content_type, key)) =
|
} else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) {
|
||||||
self.db.search_file_metadata(mxc.clone(), 0, 0)
|
|
||||||
{
|
|
||||||
// Generate a thumbnail
|
// Generate a thumbnail
|
||||||
let path = if cfg!(feature = "sha256_media") {
|
let path = if cfg!(feature = "sha256_media") {
|
||||||
services().globals.get_media_file_new(&key)
|
services().globals.get_media_file_new(&key)
|
||||||
|
@ -365,19 +318,16 @@ impl Service {
|
||||||
|
|
||||||
let use_width = nratio <= ratio;
|
let use_width = nratio <= ratio;
|
||||||
let intermediate = if use_width {
|
let intermediate = if use_width {
|
||||||
u64::from(original_height) * u64::from(width)
|
u64::from(original_height) * u64::from(width) / u64::from(original_width)
|
||||||
/ u64::from(original_width)
|
|
||||||
} else {
|
} else {
|
||||||
u64::from(original_width) * u64::from(height)
|
u64::from(original_width) * u64::from(height) / u64::from(original_height)
|
||||||
/ u64::from(original_height)
|
|
||||||
};
|
};
|
||||||
if use_width {
|
if use_width {
|
||||||
if intermediate <= u64::from(::std::u32::MAX) {
|
if intermediate <= u64::from(::std::u32::MAX) {
|
||||||
(width, intermediate as u32)
|
(width, intermediate as u32)
|
||||||
} else {
|
} else {
|
||||||
(
|
(
|
||||||
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
|
(u64::from(width) * u64::from(::std::u32::MAX) / intermediate) as u32,
|
||||||
as u32,
|
|
||||||
::std::u32::MAX,
|
::std::u32::MAX,
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
@ -386,8 +336,7 @@ impl Service {
|
||||||
} else {
|
} else {
|
||||||
(
|
(
|
||||||
::std::u32::MAX,
|
::std::u32::MAX,
|
||||||
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
|
(u64::from(height) * u64::from(::std::u32::MAX) / intermediate) as u32,
|
||||||
as u32,
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
@ -396,10 +345,7 @@ impl Service {
|
||||||
};
|
};
|
||||||
|
|
||||||
let mut thumbnail_bytes = Vec::new();
|
let mut thumbnail_bytes = Vec::new();
|
||||||
thumbnail.write_to(
|
thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageOutputFormat::Png)?;
|
||||||
&mut Cursor::new(&mut thumbnail_bytes),
|
|
||||||
image::ImageOutputFormat::Png,
|
|
||||||
)?;
|
|
||||||
|
|
||||||
// Save thumbnail in database so we don't have to generate it again next time
|
// Save thumbnail in database so we don't have to generate it again next time
|
||||||
let thumbnail_key = self.db.create_file_metadata(
|
let thumbnail_key = self.db.create_file_metadata(
|
||||||
|
@ -438,9 +384,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }
|
||||||
self.db.get_url_preview(url)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
|
pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||||
// TODO: also remove the downloaded image
|
// TODO: also remove the downloaded image
|
||||||
|
@ -448,9 +392,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
|
pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
|
||||||
let now = SystemTime::now()
|
let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time");
|
||||||
.duration_since(SystemTime::UNIX_EPOCH)
|
|
||||||
.expect("valid system time");
|
|
||||||
self.db.set_url_preview(url, data, now)
|
self.db.set_url_preview(url, data, now)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -459,9 +401,8 @@ impl Service {
|
||||||
mod tests {
|
mod tests {
|
||||||
use std::path::PathBuf;
|
use std::path::PathBuf;
|
||||||
|
|
||||||
use sha2::Digest;
|
|
||||||
|
|
||||||
use base64::{engine::general_purpose, Engine as _};
|
use base64::{engine::general_purpose, Engine as _};
|
||||||
|
use sha2::Digest;
|
||||||
|
|
||||||
use super::*;
|
use super::*;
|
||||||
|
|
||||||
|
@ -469,73 +410,40 @@ mod tests {
|
||||||
|
|
||||||
impl Data for MockedKVDatabase {
|
impl Data for MockedKVDatabase {
|
||||||
fn create_file_metadata(
|
fn create_file_metadata(
|
||||||
&self,
|
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||||
mxc: String,
|
|
||||||
width: u32,
|
|
||||||
height: u32,
|
|
||||||
content_disposition: Option<&str>,
|
|
||||||
content_type: Option<&str>,
|
|
||||||
) -> Result<Vec<u8>> {
|
) -> Result<Vec<u8>> {
|
||||||
// copied from src/database/key_value/media.rs
|
// copied from src/database/key_value/media.rs
|
||||||
let mut key = mxc.as_bytes().to_vec();
|
let mut key = mxc.as_bytes().to_vec();
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(&width.to_be_bytes());
|
key.extend_from_slice(&width.to_be_bytes());
|
||||||
key.extend_from_slice(&height.to_be_bytes());
|
key.extend_from_slice(&height.to_be_bytes());
|
||||||
key.push(0xff);
|
key.push(0xFF);
|
||||||
key.extend_from_slice(
|
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||||
content_disposition
|
key.push(0xFF);
|
||||||
.as_ref()
|
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||||
.map(|f| f.as_bytes())
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
key.push(0xff);
|
|
||||||
key.extend_from_slice(
|
|
||||||
content_type
|
|
||||||
.as_ref()
|
|
||||||
.map(|c| c.as_bytes())
|
|
||||||
.unwrap_or_default(),
|
|
||||||
);
|
|
||||||
|
|
||||||
Ok(key)
|
Ok(key)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn delete_file_mxc(&self, _mxc: String) -> Result<()> {
|
fn delete_file_mxc(&self, _mxc: String) -> Result<()> { todo!() }
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> {
|
fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> { todo!() }
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> { todo!() }
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn search_file_metadata(
|
fn search_file_metadata(
|
||||||
&self,
|
&self, _mxc: String, _width: u32, _height: u32,
|
||||||
_mxc: String,
|
|
||||||
_width: u32,
|
|
||||||
_height: u32,
|
|
||||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn remove_url_preview(&self, _url: &str) -> Result<()> {
|
fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }
|
||||||
|
|
||||||
|
fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
|
||||||
todo!()
|
todo!()
|
||||||
}
|
}
|
||||||
|
|
||||||
fn set_url_preview(
|
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { todo!() }
|
||||||
&self,
|
|
||||||
_url: &str,
|
|
||||||
_data: &UrlPreviewData,
|
|
||||||
_timestamp: std::time::Duration,
|
|
||||||
) -> Result<()> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> {
|
|
||||||
todo!()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tokio::test]
|
#[tokio::test]
|
||||||
|
@ -549,18 +457,11 @@ mod tests {
|
||||||
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
|
let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
|
||||||
let width = 100;
|
let width = 100;
|
||||||
let height = 100;
|
let height = 100;
|
||||||
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\"";
|
let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
|
||||||
|
characters like äöüß and even emoji like 🦀.png\"";
|
||||||
let content_type = "image/png";
|
let content_type = "image/png";
|
||||||
let key = media
|
let key =
|
||||||
.db
|
media.db.create_file_metadata(mxc, width, height, Some(content_disposition), Some(content_type)).unwrap();
|
||||||
.create_file_metadata(
|
|
||||||
mxc,
|
|
||||||
width,
|
|
||||||
height,
|
|
||||||
Some(content_disposition),
|
|
||||||
Some(content_type),
|
|
||||||
)
|
|
||||||
.unwrap();
|
|
||||||
let mut r = PathBuf::new();
|
let mut r = PathBuf::new();
|
||||||
r.push("/tmp");
|
r.push("/tmp");
|
||||||
r.push("media");
|
r.push("media");
|
||||||
|
|
|
@ -51,32 +51,59 @@ impl Services<'_> {
|
||||||
+ sending::Data
|
+ sending::Data
|
||||||
+ 'static,
|
+ 'static,
|
||||||
>(
|
>(
|
||||||
db: &'static D,
|
db: &'static D, config: Config,
|
||||||
config: Config,
|
|
||||||
) -> Result<Self> {
|
) -> Result<Self> {
|
||||||
Ok(Self {
|
Ok(Self {
|
||||||
appservice: appservice::Service { db },
|
appservice: appservice::Service {
|
||||||
pusher: pusher::Service { db },
|
db,
|
||||||
|
},
|
||||||
|
pusher: pusher::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
rooms: rooms::Service {
|
rooms: rooms::Service {
|
||||||
alias: rooms::alias::Service { db },
|
alias: rooms::alias::Service {
|
||||||
auth_chain: rooms::auth_chain::Service { db },
|
db,
|
||||||
directory: rooms::directory::Service { db },
|
},
|
||||||
|
auth_chain: rooms::auth_chain::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
directory: rooms::directory::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
edus: rooms::edus::Service {
|
edus: rooms::edus::Service {
|
||||||
presence: rooms::edus::presence::Service { db },
|
presence: rooms::edus::presence::Service {
|
||||||
read_receipt: rooms::edus::read_receipt::Service { db },
|
db,
|
||||||
typing: rooms::edus::typing::Service { db },
|
},
|
||||||
|
read_receipt: rooms::edus::read_receipt::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
typing: rooms::edus::typing::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
event_handler: rooms::event_handler::Service,
|
event_handler: rooms::event_handler::Service,
|
||||||
lazy_loading: rooms::lazy_loading::Service {
|
lazy_loading: rooms::lazy_loading::Service {
|
||||||
db,
|
db,
|
||||||
lazy_load_waiting: Mutex::new(HashMap::new()),
|
lazy_load_waiting: Mutex::new(HashMap::new()),
|
||||||
},
|
},
|
||||||
metadata: rooms::metadata::Service { db },
|
metadata: rooms::metadata::Service {
|
||||||
outlier: rooms::outlier::Service { db },
|
db,
|
||||||
pdu_metadata: rooms::pdu_metadata::Service { db },
|
},
|
||||||
search: rooms::search::Service { db },
|
outlier: rooms::outlier::Service {
|
||||||
short: rooms::short::Service { db },
|
db,
|
||||||
state: rooms::state::Service { db },
|
},
|
||||||
|
pdu_metadata: rooms::pdu_metadata::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
search: rooms::search::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
short: rooms::short::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
state: rooms::state::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
state_accessor: rooms::state_accessor::Service {
|
state_accessor: rooms::state_accessor::Service {
|
||||||
db,
|
db,
|
||||||
server_visibility_cache: Mutex::new(LruCache::new(
|
server_visibility_cache: Mutex::new(LruCache::new(
|
||||||
|
@ -86,7 +113,9 @@ impl Services<'_> {
|
||||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
state_cache: rooms::state_cache::Service { db },
|
state_cache: rooms::state_cache::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
state_compressor: rooms::state_compressor::Service {
|
state_compressor: rooms::state_compressor::Service {
|
||||||
db,
|
db,
|
||||||
stateinfo_cache: Mutex::new(LruCache::new(
|
stateinfo_cache: Mutex::new(LruCache::new(
|
||||||
|
@ -97,23 +126,35 @@ impl Services<'_> {
|
||||||
db,
|
db,
|
||||||
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
lasttimelinecount_cache: Mutex::new(HashMap::new()),
|
||||||
},
|
},
|
||||||
threads: rooms::threads::Service { db },
|
threads: rooms::threads::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
spaces: rooms::spaces::Service {
|
spaces: rooms::spaces::Service {
|
||||||
roomid_spacechunk_cache: Mutex::new(LruCache::new(
|
roomid_spacechunk_cache: Mutex::new(LruCache::new(
|
||||||
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
(100.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||||
)),
|
)),
|
||||||
},
|
},
|
||||||
user: rooms::user::Service { db },
|
user: rooms::user::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
transaction_ids: transaction_ids::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
|
uiaa: uiaa::Service {
|
||||||
|
db,
|
||||||
},
|
},
|
||||||
transaction_ids: transaction_ids::Service { db },
|
|
||||||
uiaa: uiaa::Service { db },
|
|
||||||
users: users::Service {
|
users: users::Service {
|
||||||
db,
|
db,
|
||||||
connections: Mutex::new(BTreeMap::new()),
|
connections: Mutex::new(BTreeMap::new()),
|
||||||
},
|
},
|
||||||
account_data: account_data::Service { db },
|
account_data: account_data::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
admin: admin::Service::build(),
|
admin: admin::Service::build(),
|
||||||
key_backups: key_backups::Service { db },
|
key_backups: key_backups::Service {
|
||||||
|
db,
|
||||||
|
},
|
||||||
media: media::Service {
|
media: media::Service {
|
||||||
db,
|
db,
|
||||||
url_preview_mutex: RwLock::new(HashMap::new()),
|
url_preview_mutex: RwLock::new(HashMap::new()),
|
||||||
|
@ -123,49 +164,14 @@ impl Services<'_> {
|
||||||
globals: globals::Service::load(db, config)?,
|
globals: globals::Service::load(db, config)?,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
fn memory_usage(&self) -> String {
|
fn memory_usage(&self) -> String {
|
||||||
let lazy_load_waiting = self
|
let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
|
||||||
.rooms
|
let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
|
||||||
.lazy_loading
|
let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
|
||||||
.lazy_load_waiting
|
let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
|
||||||
.lock()
|
let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
|
||||||
.unwrap()
|
let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();
|
||||||
.len();
|
|
||||||
let server_visibility_cache = self
|
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.server_visibility_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.len();
|
|
||||||
let user_visibility_cache = self
|
|
||||||
.rooms
|
|
||||||
.state_accessor
|
|
||||||
.user_visibility_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.len();
|
|
||||||
let stateinfo_cache = self
|
|
||||||
.rooms
|
|
||||||
.state_compressor
|
|
||||||
.stateinfo_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.len();
|
|
||||||
let lasttimelinecount_cache = self
|
|
||||||
.rooms
|
|
||||||
.timeline
|
|
||||||
.lasttimelinecount_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.len();
|
|
||||||
let roomid_spacechunk_cache = self
|
|
||||||
.rooms
|
|
||||||
.spaces
|
|
||||||
.roomid_spacechunk_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.len();
|
|
||||||
|
|
||||||
format!(
|
format!(
|
||||||
"\
|
"\
|
||||||
|
@ -174,58 +180,28 @@ server_visibility_cache: {server_visibility_cache}
|
||||||
user_visibility_cache: {user_visibility_cache}
|
user_visibility_cache: {user_visibility_cache}
|
||||||
stateinfo_cache: {stateinfo_cache}
|
stateinfo_cache: {stateinfo_cache}
|
||||||
lasttimelinecount_cache: {lasttimelinecount_cache}
|
lasttimelinecount_cache: {lasttimelinecount_cache}
|
||||||
roomid_spacechunk_cache: {roomid_spacechunk_cache}\
|
roomid_spacechunk_cache: {roomid_spacechunk_cache}"
|
||||||
"
|
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
fn clear_caches(&self, amount: u32) {
|
fn clear_caches(&self, amount: u32) {
|
||||||
if amount > 0 {
|
if amount > 0 {
|
||||||
self.rooms
|
self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
|
||||||
.lazy_loading
|
|
||||||
.lazy_load_waiting
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
if amount > 1 {
|
if amount > 1 {
|
||||||
self.rooms
|
self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
|
||||||
.state_accessor
|
|
||||||
.server_visibility_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
if amount > 2 {
|
if amount > 2 {
|
||||||
self.rooms
|
self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear();
|
||||||
.state_accessor
|
|
||||||
.user_visibility_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
if amount > 3 {
|
if amount > 3 {
|
||||||
self.rooms
|
self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
|
||||||
.state_compressor
|
|
||||||
.stateinfo_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
if amount > 4 {
|
if amount > 4 {
|
||||||
self.rooms
|
self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
|
||||||
.timeline
|
|
||||||
.lasttimelinecount_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
if amount > 5 {
|
if amount > 5 {
|
||||||
self.rooms
|
self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
|
||||||
.spaces
|
|
||||||
.roomid_spacechunk_cache
|
|
||||||
.lock()
|
|
||||||
.unwrap()
|
|
||||||
.clear();
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,23 +1,25 @@
|
||||||
use crate::Error;
|
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
|
||||||
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
canonical_json::redact_content_in_place,
|
canonical_json::redact_content_in_place,
|
||||||
events::{
|
events::{
|
||||||
room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent,
|
room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent,
|
||||||
AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent,
|
AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
|
||||||
AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType,
|
AnyTimelineEvent, StateEvent, TimelineEventType,
|
||||||
},
|
},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
|
state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
|
||||||
OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
|
OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
|
||||||
};
|
};
|
||||||
use serde::{Deserialize, Serialize};
|
use serde::{Deserialize, Serialize};
|
||||||
use serde_json::{
|
use serde_json::{
|
||||||
json,
|
json,
|
||||||
value::{to_raw_value, RawValue as RawJsonValue},
|
value::{to_raw_value, RawValue as RawJsonValue},
|
||||||
};
|
};
|
||||||
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
|
|
||||||
use tracing::warn;
|
use tracing::warn;
|
||||||
|
|
||||||
|
use crate::Error;
|
||||||
|
|
||||||
/// Content hashes of a PDU.
|
/// Content hashes of a PDU.
|
||||||
#[derive(Clone, Debug, Deserialize, Serialize)]
|
#[derive(Clone, Debug, Deserialize, Serialize)]
|
||||||
pub struct EventHash {
|
pub struct EventHash {
|
||||||
|
@ -50,11 +52,7 @@ pub struct PduEvent {
|
||||||
|
|
||||||
impl PduEvent {
|
impl PduEvent {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn redact(
|
pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
|
||||||
&mut self,
|
|
||||||
room_version_id: RoomVersionId,
|
|
||||||
reason: &PduEvent,
|
|
||||||
) -> crate::Result<()> {
|
|
||||||
self.unsigned = None;
|
self.unsigned = None;
|
||||||
|
|
||||||
let mut content = serde_json::from_str(self.content.get())
|
let mut content = serde_json::from_str(self.content.get())
|
||||||
|
@ -62,9 +60,12 @@ impl PduEvent {
|
||||||
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
|
redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
|
||||||
.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;
|
.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;
|
||||||
|
|
||||||
self.unsigned = Some(to_raw_value(&json!({
|
self.unsigned = Some(
|
||||||
|
to_raw_value(&json!({
|
||||||
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
|
"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
|
||||||
})).expect("to string always works"));
|
}))
|
||||||
|
.expect("to string always works"),
|
||||||
|
);
|
||||||
|
|
||||||
self.content = to_raw_value(&content).expect("to string always works");
|
self.content = to_raw_value(&content).expect("to string always works");
|
||||||
|
|
||||||
|
@ -73,8 +74,7 @@ impl PduEvent {
|
||||||
|
|
||||||
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
|
pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
|
||||||
if let Some(unsigned) = &self.unsigned {
|
if let Some(unsigned) = &self.unsigned {
|
||||||
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
|
let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
|
||||||
serde_json::from_str(unsigned.get())
|
|
||||||
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
|
||||||
unsigned.remove("transaction_id");
|
unsigned.remove("transaction_id");
|
||||||
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));
|
||||||
|
@ -276,36 +276,25 @@ impl PduEvent {
|
||||||
|
|
||||||
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
|
/// This does not return a full `Pdu` it is only to satisfy ruma's types.
|
||||||
#[tracing::instrument]
|
#[tracing::instrument]
|
||||||
pub fn convert_to_outgoing_federation_event(
|
pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
|
||||||
mut pdu_json: CanonicalJsonObject,
|
if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) {
|
||||||
) -> Box<RawJsonValue> {
|
|
||||||
if let Some(unsigned) = pdu_json
|
|
||||||
.get_mut("unsigned")
|
|
||||||
.and_then(|val| val.as_object_mut())
|
|
||||||
{
|
|
||||||
unsigned.remove("transaction_id");
|
unsigned.remove("transaction_id");
|
||||||
}
|
}
|
||||||
|
|
||||||
pdu_json.remove("event_id");
|
pdu_json.remove("event_id");
|
||||||
|
|
||||||
// TODO: another option would be to convert it to a canonical string to validate size
|
// TODO: another option would be to convert it to a canonical string to validate
|
||||||
// and return a Result<Raw<...>>
|
// size and return a Result<Raw<...>>
|
||||||
// serde_json::from_str::<Raw<_>>(
|
// serde_json::from_str::<Raw<_>>(
|
||||||
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is valid serde_json::Value"),
|
// ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
|
||||||
// )
|
// valid serde_json::Value"), )
|
||||||
// .expect("Raw::from_value always works")
|
// .expect("Raw::from_value always works")
|
||||||
|
|
||||||
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
|
to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_id_val(
|
pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
|
||||||
event_id: &EventId,
|
json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));
|
||||||
mut json: CanonicalJsonObject,
|
|
||||||
) -> Result<Self, serde_json::Error> {
|
|
||||||
json.insert(
|
|
||||||
"event_id".to_owned(),
|
|
||||||
CanonicalJsonValue::String(event_id.as_str().to_owned()),
|
|
||||||
);
|
|
||||||
|
|
||||||
serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
|
serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
|
||||||
}
|
}
|
||||||
|
@ -314,72 +303,46 @@ impl PduEvent {
|
||||||
impl state_res::Event for PduEvent {
|
impl state_res::Event for PduEvent {
|
||||||
type Id = Arc<EventId>;
|
type Id = Arc<EventId>;
|
||||||
|
|
||||||
fn event_id(&self) -> &Self::Id {
|
fn event_id(&self) -> &Self::Id { &self.event_id }
|
||||||
&self.event_id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn room_id(&self) -> &RoomId {
|
fn room_id(&self) -> &RoomId { &self.room_id }
|
||||||
&self.room_id
|
|
||||||
}
|
|
||||||
|
|
||||||
fn sender(&self) -> &UserId {
|
fn sender(&self) -> &UserId { &self.sender }
|
||||||
&self.sender
|
|
||||||
}
|
|
||||||
|
|
||||||
fn event_type(&self) -> &TimelineEventType {
|
fn event_type(&self) -> &TimelineEventType { &self.kind }
|
||||||
&self.kind
|
|
||||||
}
|
|
||||||
|
|
||||||
fn content(&self) -> &RawJsonValue {
|
fn content(&self) -> &RawJsonValue { &self.content }
|
||||||
&self.content
|
|
||||||
}
|
|
||||||
|
|
||||||
fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch {
|
fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) }
|
||||||
MilliSecondsSinceUnixEpoch(self.origin_server_ts)
|
|
||||||
}
|
|
||||||
|
|
||||||
fn state_key(&self) -> Option<&str> {
|
fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }
|
||||||
self.state_key.as_deref()
|
|
||||||
}
|
|
||||||
|
|
||||||
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) }
|
||||||
Box::new(self.prev_events.iter())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
|
fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) }
|
||||||
Box::new(self.auth_events.iter())
|
|
||||||
}
|
|
||||||
|
|
||||||
fn redacts(&self) -> Option<&Self::Id> {
|
fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
|
||||||
self.redacts.as_ref()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// These impl's allow us to dedup state snapshots when resolving state
|
// These impl's allow us to dedup state snapshots when resolving state
|
||||||
// for incoming events (federation/send/{txn}).
|
// for incoming events (federation/send/{txn}).
|
||||||
impl Eq for PduEvent {}
|
impl Eq for PduEvent {}
|
||||||
impl PartialEq for PduEvent {
|
impl PartialEq for PduEvent {
|
||||||
fn eq(&self, other: &Self) -> bool {
|
fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id }
|
||||||
self.event_id == other.event_id
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
impl PartialOrd for PduEvent {
|
impl PartialOrd for PduEvent {
|
||||||
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
|
fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
|
||||||
Some(self.cmp(other))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
impl Ord for PduEvent {
|
impl Ord for PduEvent {
|
||||||
fn cmp(&self, other: &Self) -> Ordering {
|
fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) }
|
||||||
self.event_id.cmp(&other.event_id)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Generates a correct eventId for the incoming pdu.
|
/// Generates a correct eventId for the incoming pdu.
|
||||||
///
|
///
|
||||||
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
|
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
|
||||||
|
/// CanonicalJsonValue>`.
|
||||||
pub(crate) fn gen_event_id_canonical_json(
|
pub(crate) fn gen_event_id_canonical_json(
|
||||||
pdu: &RawJsonValue,
|
pdu: &RawJsonValue, room_version_id: &RoomVersionId,
|
||||||
room_version_id: &RoomVersionId,
|
|
||||||
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
|
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
|
||||||
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
|
||||||
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
|
warn!("Error parsing incoming event {:?}: {:?}", pdu, e);
|
||||||
|
@ -389,8 +352,7 @@ pub(crate) fn gen_event_id_canonical_json(
|
||||||
let event_id = format!(
|
let event_id = format!(
|
||||||
"${}",
|
"${}",
|
||||||
// Anything higher than version3 behaves the same
|
// Anything higher than version3 behaves the same
|
||||||
ruma::signatures::reference_hash(&value, room_version_id)
|
ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes")
|
||||||
.expect("ruma can calculate reference hashes")
|
|
||||||
)
|
)
|
||||||
.try_into()
|
.try_into()
|
||||||
.expect("ruma's reference hashes are valid event ids");
|
.expect("ruma's reference hashes are valid event ids");
|
||||||
|
|
|
@ -1,9 +1,10 @@
|
||||||
use crate::Result;
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::client::push::{set_pusher, Pusher},
|
api::client::push::{set_pusher, Pusher},
|
||||||
UserId,
|
UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
|
fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;
|
||||||
|
|
||||||
|
@ -11,6 +12,5 @@ pub trait Data: Send + Sync {
|
||||||
|
|
||||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
|
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;
|
||||||
|
|
||||||
fn get_pushkeys<'a>(&'a self, sender: &UserId)
|
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
|
||||||
-> Box<dyn Iterator<Item = Result<String>> + 'a>;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,8 @@
|
||||||
mod data;
|
mod data;
|
||||||
pub use data::Data;
|
use std::{fmt::Debug, mem};
|
||||||
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
|
|
||||||
|
|
||||||
use crate::{services, Error, PduEvent, Result};
|
|
||||||
use bytes::BytesMut;
|
use bytes::BytesMut;
|
||||||
|
pub use data::Data;
|
||||||
use ruma::{
|
use ruma::{
|
||||||
api::{
|
api::{
|
||||||
client::push::{set_pusher, Pusher, PusherKind},
|
client::push::{set_pusher, Pusher, PusherKind},
|
||||||
|
@ -13,15 +12,17 @@ use ruma::{
|
||||||
},
|
},
|
||||||
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
|
||||||
},
|
},
|
||||||
events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
|
events::{
|
||||||
push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
|
||||||
|
},
|
||||||
|
push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
uint, RoomId, UInt, UserId,
|
uint, RoomId, UInt, UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
use std::{fmt::Debug, mem};
|
|
||||||
use tracing::{info, warn};
|
use tracing::{info, warn};
|
||||||
|
|
||||||
|
use crate::{services, Error, PduEvent, Result};
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
pub db: &'static dyn Data,
|
pub db: &'static dyn Data,
|
||||||
}
|
}
|
||||||
|
@ -35,31 +36,21 @@ impl Service {
|
||||||
self.db.get_pusher(sender, pushkey)
|
self.db.get_pusher(sender, pushkey)
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }
|
||||||
self.db.get_pushers(sender)
|
|
||||||
}
|
|
||||||
|
|
||||||
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
|
pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
|
||||||
self.db.get_pushkeys(sender)
|
self.db.get_pushkeys(sender)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, destination, request))]
|
#[tracing::instrument(skip(self, destination, request))]
|
||||||
pub async fn send_request<T>(
|
pub async fn send_request<T>(&self, destination: &str, request: T) -> Result<T::IncomingResponse>
|
||||||
&self,
|
|
||||||
destination: &str,
|
|
||||||
request: T,
|
|
||||||
) -> Result<T::IncomingResponse>
|
|
||||||
where
|
where
|
||||||
T: OutgoingRequest + Debug,
|
T: OutgoingRequest + Debug,
|
||||||
{
|
{
|
||||||
let destination = destination.replace(services().globals.notification_push_path(), "");
|
let destination = destination.replace(services().globals.notification_push_path(), "");
|
||||||
|
|
||||||
let http_request = request
|
let http_request = request
|
||||||
.try_into_http_request::<BytesMut>(
|
.try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_0])
|
||||||
&destination,
|
|
||||||
SendAccessToken::IfRequired(""),
|
|
||||||
&[MatrixVersion::V1_0],
|
|
||||||
)
|
|
||||||
.map_err(|e| {
|
.map_err(|e| {
|
||||||
warn!("Failed to find destination {}: {}", destination, e);
|
warn!("Failed to find destination {}: {}", destination, e);
|
||||||
Error::BadServerResponse("Invalid destination")
|
Error::BadServerResponse("Invalid destination")
|
||||||
|
@ -72,24 +63,16 @@ impl Service {
|
||||||
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));
|
||||||
|
|
||||||
let url = reqwest_request.url().clone();
|
let url = reqwest_request.url().clone();
|
||||||
let response = services()
|
let response = services().globals.default_client().execute(reqwest_request).await;
|
||||||
.globals
|
|
||||||
.default_client()
|
|
||||||
.execute(reqwest_request)
|
|
||||||
.await;
|
|
||||||
|
|
||||||
match response {
|
match response {
|
||||||
Ok(mut response) => {
|
Ok(mut response) => {
|
||||||
// reqwest::Response -> http::Response conversion
|
// reqwest::Response -> http::Response conversion
|
||||||
let status = response.status();
|
let status = response.status();
|
||||||
let mut http_response_builder = http::Response::builder()
|
let mut http_response_builder = http::Response::builder().status(status).version(response.version());
|
||||||
.status(status)
|
|
||||||
.version(response.version());
|
|
||||||
mem::swap(
|
mem::swap(
|
||||||
response.headers_mut(),
|
response.headers_mut(),
|
||||||
http_response_builder
|
http_response_builder.headers_mut().expect("http::response::Builder is usable"),
|
||||||
.headers_mut()
|
|
||||||
.expect("http::response::Builder is usable"),
|
|
||||||
);
|
);
|
||||||
|
|
||||||
let body = response.bytes().await.unwrap_or_else(|e| {
|
let body = response.bytes().await.unwrap_or_else(|e| {
|
||||||
|
@ -108,33 +91,23 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
let response = T::IncomingResponse::try_from_http_response(
|
let response = T::IncomingResponse::try_from_http_response(
|
||||||
http_response_builder
|
http_response_builder.body(body).expect("reqwest body is valid http body"),
|
||||||
.body(body)
|
|
||||||
.expect("reqwest body is valid http body"),
|
|
||||||
);
|
);
|
||||||
response.map_err(|_| {
|
response.map_err(|_| {
|
||||||
info!(
|
info!("Push gateway returned invalid response bytes {}\n{}", destination, url);
|
||||||
"Push gateway returned invalid response bytes {}\n{}",
|
|
||||||
destination, url
|
|
||||||
);
|
|
||||||
Error::BadServerResponse("Push gateway returned bad response.")
|
Error::BadServerResponse("Push gateway returned bad response.")
|
||||||
})
|
})
|
||||||
}
|
},
|
||||||
Err(e) => {
|
Err(e) => {
|
||||||
warn!("Could not send request to pusher {}: {}", destination, e);
|
warn!("Could not send request to pusher {}: {}", destination, e);
|
||||||
Err(e.into())
|
Err(e.into())
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
|
||||||
pub async fn send_push_notice(
|
pub async fn send_push_notice(
|
||||||
&self,
|
&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
|
||||||
user: &UserId,
|
|
||||||
unread: UInt,
|
|
||||||
pusher: &Pusher,
|
|
||||||
ruleset: Ruleset,
|
|
||||||
pdu: &PduEvent,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
let mut notify = None;
|
let mut notify = None;
|
||||||
let mut tweaks = Vec::new();
|
let mut tweaks = Vec::new();
|
||||||
|
@ -150,19 +123,13 @@ impl Service {
|
||||||
.transpose()?
|
.transpose()?
|
||||||
.unwrap_or_default();
|
.unwrap_or_default();
|
||||||
|
|
||||||
for action in self.get_actions(
|
for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? {
|
||||||
user,
|
|
||||||
&ruleset,
|
|
||||||
&power_levels,
|
|
||||||
&pdu.to_sync_room_event(),
|
|
||||||
&pdu.room_id,
|
|
||||||
)? {
|
|
||||||
let n = match action {
|
let n = match action {
|
||||||
Action::Notify => true,
|
Action::Notify => true,
|
||||||
Action::SetTweak(tweak) => {
|
Action::SetTweak(tweak) => {
|
||||||
tweaks.push(tweak.clone());
|
tweaks.push(tweak.clone());
|
||||||
continue;
|
continue;
|
||||||
}
|
},
|
||||||
_ => false,
|
_ => false,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -185,12 +152,8 @@ impl Service {
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, user, ruleset, pdu))]
|
#[tracing::instrument(skip(self, user, ruleset, pdu))]
|
||||||
pub fn get_actions<'a>(
|
pub fn get_actions<'a>(
|
||||||
&self,
|
&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
|
||||||
user: &UserId,
|
pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
|
||||||
ruleset: &'a Ruleset,
|
|
||||||
power_levels: &RoomPowerLevelsEventContent,
|
|
||||||
pdu: &Raw<AnySyncTimelineEvent>,
|
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Result<&'a [Action]> {
|
) -> Result<&'a [Action]> {
|
||||||
let power_levels = PushConditionPowerLevelsCtx {
|
let power_levels = PushConditionPowerLevelsCtx {
|
||||||
users: power_levels.users.clone(),
|
users: power_levels.users.clone(),
|
||||||
|
@ -200,18 +163,9 @@ impl Service {
|
||||||
|
|
||||||
let ctx = PushConditionRoomCtx {
|
let ctx = PushConditionRoomCtx {
|
||||||
room_id: room_id.to_owned(),
|
room_id: room_id.to_owned(),
|
||||||
member_count: UInt::from(
|
member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32),
|
||||||
services()
|
|
||||||
.rooms
|
|
||||||
.state_cache
|
|
||||||
.room_joined_count(room_id)?
|
|
||||||
.unwrap_or(1) as u32,
|
|
||||||
),
|
|
||||||
user_id: user.to_owned(),
|
user_id: user.to_owned(),
|
||||||
user_display_name: services()
|
user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()),
|
||||||
.users
|
|
||||||
.displayname(user)?
|
|
||||||
.unwrap_or_else(|| user.localpart().to_owned()),
|
|
||||||
power_levels: Some(power_levels),
|
power_levels: Some(power_levels),
|
||||||
};
|
};
|
||||||
|
|
||||||
|
@ -219,19 +173,14 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
|
||||||
async fn send_notice(
|
async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
|
||||||
&self,
|
|
||||||
unread: UInt,
|
|
||||||
pusher: &Pusher,
|
|
||||||
tweaks: Vec<Tweak>,
|
|
||||||
event: &PduEvent,
|
|
||||||
) -> Result<()> {
|
|
||||||
// TODO: email
|
// TODO: email
|
||||||
match &pusher.kind {
|
match &pusher.kind {
|
||||||
PusherKind::Http(http) => {
|
PusherKind::Http(http) => {
|
||||||
// TODO:
|
// TODO:
|
||||||
// Two problems with this
|
// Two problems with this
|
||||||
// 1. if "event_id_only" is the only format kind it seems we should never add more info
|
// 1. if "event_id_only" is the only format kind it seems we should never add
|
||||||
|
// more info
|
||||||
// 2. can pusher/devices have conflicting formats
|
// 2. can pusher/devices have conflicting formats
|
||||||
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
|
let event_id_only = http.format == Some(PushFormat::EventIdOnly);
|
||||||
|
|
||||||
|
@ -254,36 +203,31 @@ impl Service {
|
||||||
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
notifi.counts = NotificationCounts::new(unread, uint!(0));
|
||||||
|
|
||||||
if event.kind == TimelineEventType::RoomEncrypted
|
if event.kind == TimelineEventType::RoomEncrypted
|
||||||
|| tweaks
|
|| tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
||||||
.iter()
|
|
||||||
.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
|
|
||||||
{
|
{
|
||||||
notifi.prio = NotificationPriority::High;
|
notifi.prio = NotificationPriority::High;
|
||||||
}
|
}
|
||||||
|
|
||||||
if event_id_only {
|
if event_id_only {
|
||||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
|
||||||
.await?;
|
|
||||||
} else {
|
} else {
|
||||||
notifi.sender = Some(event.sender.clone());
|
notifi.sender = Some(event.sender.clone());
|
||||||
notifi.event_type = Some(event.kind.clone());
|
notifi.event_type = Some(event.kind.clone());
|
||||||
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
|
notifi.content = serde_json::value::to_raw_value(&event.content).ok();
|
||||||
|
|
||||||
if event.kind == TimelineEventType::RoomMember {
|
if event.kind == TimelineEventType::RoomMember {
|
||||||
notifi.user_is_target =
|
notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
|
||||||
event.state_key.as_deref() == Some(event.sender.as_str());
|
|
||||||
}
|
}
|
||||||
|
|
||||||
notifi.sender_display_name = services().users.displayname(&event.sender)?;
|
notifi.sender_display_name = services().users.displayname(&event.sender)?;
|
||||||
|
|
||||||
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
|
notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;
|
||||||
|
|
||||||
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
|
self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
|
||||||
.await?;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
Ok(())
|
Ok(())
|
||||||
}
|
},
|
||||||
// TODO: Handle email
|
// TODO: Handle email
|
||||||
//PusherKind::Email(_) => Ok(()),
|
//PusherKind::Email(_) => Ok(()),
|
||||||
_ => Ok(()),
|
_ => Ok(()),
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::Result;
|
|
||||||
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
/// Creates or updates the alias to the given room id.
|
/// Creates or updates the alias to the given room id.
|
||||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
|
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;
|
||||||
|
@ -13,12 +14,9 @@ pub trait Data: Send + Sync {
|
||||||
|
|
||||||
/// Returns all local aliases that point to the given room
|
/// Returns all local aliases that point to the given room
|
||||||
fn local_aliases_for_room<'a>(
|
fn local_aliases_for_room<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId,
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
|
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;
|
||||||
|
|
||||||
/// Returns all local aliases on the server
|
/// Returns all local aliases on the server
|
||||||
fn all_local_aliases<'a>(
|
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
|
||||||
&'a self,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +1,9 @@
|
||||||
mod data;
|
mod data;
|
||||||
|
|
||||||
pub use data::Data;
|
pub use data::Data;
|
||||||
|
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};
|
|
||||||
|
|
||||||
pub struct Service {
|
pub struct Service {
|
||||||
pub db: &'static dyn Data,
|
pub db: &'static dyn Data,
|
||||||
|
@ -11,14 +11,10 @@ pub struct Service {
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) }
|
||||||
self.db.set_alias(alias, room_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }
|
||||||
self.db.remove_alias(alias)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
|
||||||
|
@ -27,16 +23,13 @@ impl Service {
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn local_aliases_for_room<'a>(
|
pub fn local_aliases_for_room<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId,
|
||||||
room_id: &RoomId,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||||
self.db.local_aliases_for_room(room_id)
|
self.db.local_aliases_for_room(room_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn all_local_aliases<'a>(
|
pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||||
&'a self,
|
|
||||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
|
||||||
self.db.all_local_aliases()
|
self.db.all_local_aliases()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,11 +1,8 @@
|
||||||
use crate::Result;
|
|
||||||
use std::{collections::HashSet, sync::Arc};
|
use std::{collections::HashSet, sync::Arc};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
fn get_cached_eventid_authchain(
|
fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<HashSet<u64>>>>;
|
||||||
&self,
|
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()>;
|
||||||
shorteventid: &[u64],
|
|
||||||
) -> Result<Option<Arc<HashSet<u64>>>>;
|
|
||||||
fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>)
|
|
||||||
-> Result<()>;
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -26,9 +26,7 @@ impl Service {
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, starting_events))]
|
#[tracing::instrument(skip(self, starting_events))]
|
||||||
pub async fn get_auth_chain<'a>(
|
pub async fn get_auth_chain<'a>(
|
||||||
&self,
|
&self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
|
||||||
room_id: &RoomId,
|
|
||||||
starting_events: Vec<Arc<EventId>>,
|
|
||||||
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
|
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
|
||||||
const NUM_BUCKETS: usize = 50;
|
const NUM_BUCKETS: usize = 50;
|
||||||
|
|
||||||
|
@ -55,11 +53,7 @@ impl Service {
|
||||||
}
|
}
|
||||||
|
|
||||||
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
|
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
|
||||||
if let Some(cached) = services()
|
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
|
||||||
.rooms
|
|
||||||
.auth_chain
|
|
||||||
.get_cached_eventid_authchain(&chunk_key)?
|
|
||||||
{
|
|
||||||
hits += 1;
|
hits += 1;
|
||||||
full_auth_chain.extend(cached.iter().copied());
|
full_auth_chain.extend(cached.iter().copied());
|
||||||
continue;
|
continue;
|
||||||
|
@ -71,20 +65,13 @@ impl Service {
|
||||||
let mut misses2 = 0;
|
let mut misses2 = 0;
|
||||||
let mut i = 0;
|
let mut i = 0;
|
||||||
for (sevent_id, event_id) in chunk {
|
for (sevent_id, event_id) in chunk {
|
||||||
if let Some(cached) = services()
|
if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
|
||||||
.rooms
|
|
||||||
.auth_chain
|
|
||||||
.get_cached_eventid_authchain(&[sevent_id])?
|
|
||||||
{
|
|
||||||
hits2 += 1;
|
hits2 += 1;
|
||||||
chunk_cache.extend(cached.iter().copied());
|
chunk_cache.extend(cached.iter().copied());
|
||||||
} else {
|
} else {
|
||||||
misses2 += 1;
|
misses2 += 1;
|
||||||
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
|
let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
|
||||||
services()
|
services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
|
||||||
.rooms
|
|
||||||
.auth_chain
|
|
||||||
.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
|
|
||||||
debug!(
|
debug!(
|
||||||
event_id = ?event_id,
|
event_id = ?event_id,
|
||||||
chain_length = ?auth_chain.len(),
|
chain_length = ?auth_chain.len(),
|
||||||
|
@ -105,10 +92,7 @@ impl Service {
|
||||||
"Chunk missed",
|
"Chunk missed",
|
||||||
);
|
);
|
||||||
let chunk_cache = Arc::new(chunk_cache);
|
let chunk_cache = Arc::new(chunk_cache);
|
||||||
services()
|
services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
|
||||||
.rooms
|
|
||||||
.auth_chain
|
|
||||||
.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
|
|
||||||
full_auth_chain.extend(chunk_cache.iter());
|
full_auth_chain.extend(chunk_cache.iter());
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -119,9 +103,7 @@ impl Service {
|
||||||
"Auth chain stats",
|
"Auth chain stats",
|
||||||
);
|
);
|
||||||
|
|
||||||
Ok(full_auth_chain
|
Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
|
||||||
.into_iter()
|
|
||||||
.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
#[tracing::instrument(skip(self, event_id))]
|
#[tracing::instrument(skip(self, event_id))]
|
||||||
|
@ -136,23 +118,20 @@ impl Service {
|
||||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
|
return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
|
||||||
}
|
}
|
||||||
for auth_event in &pdu.auth_events {
|
for auth_event in &pdu.auth_events {
|
||||||
let sauthevent = services()
|
let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?;
|
||||||
.rooms
|
|
||||||
.short
|
|
||||||
.get_or_create_shorteventid(auth_event)?;
|
|
||||||
|
|
||||||
if !found.contains(&sauthevent) {
|
if !found.contains(&sauthevent) {
|
||||||
found.insert(sauthevent);
|
found.insert(sauthevent);
|
||||||
todo.push(auth_event.clone());
|
todo.push(auth_event.clone());
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
Ok(None) => {
|
Ok(None) => {
|
||||||
warn!(?event_id, "Could not find pdu mentioned in auth events");
|
warn!(?event_id, "Could not find pdu mentioned in auth events");
|
||||||
}
|
},
|
||||||
Err(error) => {
|
Err(error) => {
|
||||||
error!(?event_id, ?error, "Could not load event in auth chain");
|
error!(?event_id, ?error, "Could not load event in auth chain");
|
||||||
}
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
use crate::Result;
|
|
||||||
use ruma::{OwnedRoomId, RoomId};
|
use ruma::{OwnedRoomId, RoomId};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
/// Adds the room to the public room directory
|
/// Adds the room to the public room directory
|
||||||
fn set_public(&self, room_id: &RoomId) -> Result<()>;
|
fn set_public(&self, room_id: &RoomId) -> Result<()>;
|
||||||
|
|
|
@ -11,22 +11,14 @@ pub struct Service {
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
|
||||||
self.db.set_public(room_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
|
||||||
self.db.set_not_public(room_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
|
||||||
self.db.is_public_room(room_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
#[tracing::instrument(skip(self))]
|
#[tracing::instrument(skip(self))]
|
||||||
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
|
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
|
||||||
self.db.public_rooms()
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,33 +1,27 @@
|
||||||
|
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
|
||||||
|
|
||||||
use crate::Result;
|
use crate::Result;
|
||||||
use ruma::{
|
|
||||||
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
|
|
||||||
};
|
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
/// Returns the latest presence event for the given user in the given room.
|
/// Returns the latest presence event for the given user in the given room.
|
||||||
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;
|
fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;
|
||||||
|
|
||||||
/// Pings the presence of the given user in the given room, setting the specified state.
|
/// Pings the presence of the given user in the given room, setting the
|
||||||
|
/// specified state.
|
||||||
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;
|
fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;
|
||||||
|
|
||||||
/// Adds a presence event which will be saved until a new event replaces it.
|
/// Adds a presence event which will be saved until a new event replaces it.
|
||||||
fn set_presence(
|
fn set_presence(
|
||||||
&self,
|
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||||
room_id: &RoomId,
|
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||||
user_id: &UserId,
|
|
||||||
presence_state: PresenceState,
|
|
||||||
currently_active: Option<bool>,
|
|
||||||
last_active_ago: Option<UInt>,
|
|
||||||
status_msg: Option<String>,
|
|
||||||
) -> Result<()>;
|
) -> Result<()>;
|
||||||
|
|
||||||
/// Removes the presence record for the given user from the database.
|
/// Removes the presence record for the given user from the database.
|
||||||
fn remove_presence(&self, user_id: &UserId) -> Result<()>;
|
fn remove_presence(&self, user_id: &UserId) -> Result<()>;
|
||||||
|
|
||||||
/// Returns the most recent presence updates that happened after the event with id `since`.
|
/// Returns the most recent presence updates that happened after the event
|
||||||
|
/// with id `since`.
|
||||||
fn presence_since<'a>(
|
fn presence_since<'a>(
|
||||||
&'a self,
|
&'a self, room_id: &RoomId, since: u64,
|
||||||
room_id: &RoomId,
|
|
||||||
since: u64,
|
|
||||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
|
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
|
||||||
}
|
}
|
||||||
|
|
|
@ -15,7 +15,8 @@ use tracing::debug;
|
||||||
|
|
||||||
use crate::{services, utils, Error, Result};
|
use crate::{services, utils, Error, Result};
|
||||||
|
|
||||||
/// Represents data required to be kept in order to implement the presence specification.
|
/// Represents data required to be kept in order to implement the presence
|
||||||
|
/// specification.
|
||||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||||
pub struct Presence {
|
pub struct Presence {
|
||||||
pub state: PresenceState,
|
pub state: PresenceState,
|
||||||
|
@ -27,11 +28,7 @@ pub struct Presence {
|
||||||
|
|
||||||
impl Presence {
|
impl Presence {
|
||||||
pub fn new(
|
pub fn new(
|
||||||
state: PresenceState,
|
state: PresenceState, currently_active: bool, last_active_ts: u64, last_count: u64, status_msg: Option<String>,
|
||||||
currently_active: bool,
|
|
||||||
last_active_ts: u64,
|
|
||||||
last_count: u64,
|
|
||||||
status_msg: Option<String>,
|
|
||||||
) -> Self {
|
) -> Self {
|
||||||
Self {
|
Self {
|
||||||
state,
|
state,
|
||||||
|
@ -43,13 +40,11 @@ impl Presence {
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
|
pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
|
||||||
serde_json::from_slice(bytes)
|
serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
|
||||||
.map_err(|_| Error::bad_database("Invalid presence data in database"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
|
pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
|
||||||
serde_json::to_vec(self)
|
serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
|
||||||
.map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Creates a PresenceEvent from available data.
|
/// Creates a PresenceEvent from available data.
|
||||||
|
@ -58,9 +53,7 @@ impl Presence {
|
||||||
let last_active_ago = if self.currently_active {
|
let last_active_ago = if self.currently_active {
|
||||||
None
|
None
|
||||||
} else {
|
} else {
|
||||||
Some(UInt::new_saturating(
|
Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts)))
|
||||||
now.saturating_sub(self.last_active_ts),
|
|
||||||
))
|
|
||||||
};
|
};
|
||||||
|
|
||||||
Ok(PresenceEvent {
|
Ok(PresenceEvent {
|
||||||
|
@ -83,49 +76,31 @@ pub struct Service {
|
||||||
|
|
||||||
impl Service {
|
impl Service {
|
||||||
/// Returns the latest presence event for the given user in the given room.
|
/// Returns the latest presence event for the given user in the given room.
|
||||||
pub fn get_presence(
|
pub fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
|
||||||
&self,
|
|
||||||
room_id: &RoomId,
|
|
||||||
user_id: &UserId,
|
|
||||||
) -> Result<Option<PresenceEvent>> {
|
|
||||||
self.db.get_presence(room_id, user_id)
|
self.db.get_presence(room_id, user_id)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Pings the presence of the given user in the given room, setting the specified state.
|
/// Pings the presence of the given user in the given room, setting the
|
||||||
|
/// specified state.
|
||||||
pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
|
||||||
self.db.ping_presence(user_id, new_state)
|
self.db.ping_presence(user_id, new_state)
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Adds a presence event which will be saved until a new event replaces it.
|
/// Adds a presence event which will be saved until a new event replaces it.
|
||||||
pub fn set_presence(
|
pub fn set_presence(
|
||||||
&self,
|
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||||
room_id: &RoomId,
|
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||||
user_id: &UserId,
|
|
||||||
presence_state: PresenceState,
|
|
||||||
currently_active: Option<bool>,
|
|
||||||
last_active_ago: Option<UInt>,
|
|
||||||
status_msg: Option<String>,
|
|
||||||
) -> Result<()> {
|
) -> Result<()> {
|
||||||
self.db.set_presence(
|
self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg)
|
||||||
room_id,
|
|
||||||
user_id,
|
|
||||||
presence_state,
|
|
||||||
currently_active,
|
|
||||||
last_active_ago,
|
|
||||||
status_msg,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
/// Removes the presence record for the given user from the database.
|
/// Removes the presence record for the given user from the database.
|
||||||
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> {
|
pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }
|
||||||
self.db.remove_presence(user_id)
|
|
||||||
}
|
|
||||||
|
|
||||||
/// Returns the most recent presence updates that happened after the event with id `since`.
|
/// Returns the most recent presence updates that happened after the event
|
||||||
|
/// with id `since`.
|
||||||
pub fn presence_since(
|
pub fn presence_since(
|
||||||
&self,
|
&self, room_id: &RoomId, since: u64,
|
||||||
room_id: &RoomId,
|
|
||||||
since: u64,
|
|
||||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)>> {
|
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)>> {
|
||||||
self.db.presence_since(room_id, since)
|
self.db.presence_since(room_id, since)
|
||||||
}
|
}
|
||||||
|
@ -167,11 +142,7 @@ fn process_presence_timer(user_id: OwnedUserId) -> Result<()> {
|
||||||
let mut status_msg = None;
|
let mut status_msg = None;
|
||||||
|
|
||||||
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
|
||||||
let presence_event = services()
|
let presence_event = services().rooms.edus.presence.get_presence(&room_id?, &user_id)?;
|
||||||
.rooms
|
|
||||||
.edus
|
|
||||||
.presence
|
|
||||||
.get_presence(&room_id?, &user_id)?;
|
|
||||||
|
|
||||||
if let Some(presence_event) = presence_event {
|
if let Some(presence_event) = presence_event {
|
||||||
presence_state = presence_event.content.presence;
|
presence_state = presence_event.content.presence;
|
||||||
|
@ -183,12 +154,8 @@ fn process_presence_timer(user_id: OwnedUserId) -> Result<()> {
|
||||||
}
|
}
|
||||||
|
|
||||||
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
|
let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
|
||||||
(PresenceState::Online, Some(ago)) if ago >= idle_timeout => {
|
(PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable),
|
||||||
Some(PresenceState::Unavailable)
|
(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline),
|
||||||
}
|
|
||||||
(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => {
|
|
||||||
Some(PresenceState::Offline)
|
|
||||||
}
|
|
||||||
_ => None,
|
_ => None,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
|
|
@ -1,25 +1,21 @@
|
||||||
use crate::Result;
|
|
||||||
use ruma::{
|
use ruma::{
|
||||||
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
|
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
|
||||||
serde::Raw,
|
serde::Raw,
|
||||||
OwnedUserId, RoomId, UserId,
|
OwnedUserId, RoomId, UserId,
|
||||||
};
|
};
|
||||||
|
|
||||||
|
use crate::Result;
|
||||||
|
|
||||||
type AnySyncEphemeralRoomEventIter<'a> =
|
type AnySyncEphemeralRoomEventIter<'a> =
|
||||||
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
|
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
|
||||||
|
|
||||||
pub trait Data: Send + Sync {
|
pub trait Data: Send + Sync {
|
||||||
/// Replaces the previous read receipt.
|
/// Replaces the previous read receipt.
|
||||||
fn readreceipt_update(
|
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>;
|
||||||
&self,
|
|
||||||
user_id: &UserId,
|
|
||||||
room_id: &RoomId,
|
|
||||||
event: ReceiptEvent,
|
|
||||||
) -> Result<()>;
|
|
||||||
|
|
||||||
/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
|
/// Returns an iterator over the most recent read_receipts in a room that
|
||||||
fn readreceipts_since(&self, room_id: &RoomId, since: u64)
|
/// happened after the event with id `since`.
|
||||||
-> AnySyncEphemeralRoomEventIter<'_>;
|
fn readreceipts_since(&self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'_>;
|
||||||
|
|
||||||
/// Sets a private read marker at `count`.
|
/// Sets a private read marker at `count`.
|
||||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;
|
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;
|
||||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue