add rustfmt.toml, format entire codebase
Signed-off-by: strawberry <strawberry@puppygock.gay>
parent 9fd521f041
commit f419c64aca
144 changed files with 25573 additions and 31053 deletions
rustfmt.toml  (new file, 27 lines)

@@ -0,0 +1,27 @@
edition = "2021"
condense_wildcard_suffixes = true
format_code_in_doc_comments = true
format_macro_bodies = true
format_macro_matchers = true
format_strings = true
hex_literal_case = "Upper"
max_width = 120
tab_spaces = 4
array_width = 80
comment_width = 80
wrap_comments = true
fn_params_layout = "Compressed"
fn_call_width = 80
fn_single_line = true
hard_tabs = true
match_block_trailing_comma = true
imports_granularity = "Crate"
normalize_comments = false
reorder_impl_items = true
reorder_imports = true
group_imports = "StdExternalCrate"
newline_style = "Unix"
use_field_init_shorthand = true
use_small_heuristics = "Off"
use_try_shorthand = true
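
Most of the options above (for example format_strings, format_macro_bodies, imports_granularity, and group_imports) are unstable rustfmt options, so they only take effect when rustfmt runs on a nightly toolchain; stable rustfmt warns about them and skips them. A minimal sketch of applying and verifying this configuration, assuming a nightly toolchain is installed:

    # reformat every crate in the workspace using the rustfmt.toml above
    cargo +nightly fmt --all

    # CI-style check: fails if any file is not already formatted
    cargo +nightly fmt --all -- --check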
@@ -1,18 +1,16 @@
use crate::{services, utils, Error, Result};
use bytes::BytesMut;
use ruma::api::{
    appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
};
use std::{fmt::Debug, mem, time::Duration};

use bytes::BytesMut;
use ruma::api::{appservice::Registration, IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken};
use tracing::warn;

use crate::{services, utils, Error, Result};

/// Sends a request to an appservice
///
/// Only returns None if there is no url specified in the appservice registration file
pub(crate) async fn send_request<T>(
    registration: Registration,
    request: T,
) -> Option<Result<T::IncomingResponse>>
/// Only returns None if there is no url specified in the appservice
/// registration file
pub(crate) async fn send_request<T>(registration: Registration, request: T) -> Option<Result<T::IncomingResponse>>
where
    T: OutgoingRequest + Debug,
{
@@ -40,25 +38,16 @@ where
        "?"
    };

    parts.path_and_query = Some(
        (old_path_and_query + symbol + "access_token=" + hs_token)
            .parse()
            .unwrap(),
    );
    parts.path_and_query = Some((old_path_and_query + symbol + "access_token=" + hs_token).parse().unwrap());
    *http_request.uri_mut() = parts.try_into().expect("our manipulation is always valid");

    let mut reqwest_request = reqwest::Request::try_from(http_request)
        .expect("all http requests are valid reqwest requests");
    let mut reqwest_request =
        reqwest::Request::try_from(http_request).expect("all http requests are valid reqwest requests");

    *reqwest_request.timeout_mut() = Some(Duration::from_secs(120));

    let url = reqwest_request.url().clone();
    let mut response = match services()
        .globals
        .default_client()
        .execute(reqwest_request)
        .await
    {
    let mut response = match services().globals.default_client().execute(reqwest_request).await {
        Ok(r) => r,
        Err(e) => {
            warn!(
@@ -66,19 +55,15 @@ where
                registration.id, destination, e
            );
            return Some(Err(e.into()));
        }
        },
    };

    // reqwest::Response -> http::Response conversion
    let status = response.status();
    let mut http_response_builder = http::Response::builder()
        .status(status)
        .version(response.version());
    let mut http_response_builder = http::Response::builder().status(status).version(response.version());
    mem::swap(
        response.headers_mut(),
        http_response_builder
            .headers_mut()
            .expect("http::response::Builder is usable"),
        http_response_builder.headers_mut().expect("http::response::Builder is usable"),
    );

    let body = response.bytes().await.unwrap_or_else(|e| {
@@ -97,15 +82,10 @@ where
    }

    let response = T::IncomingResponse::try_from_http_response(
        http_response_builder
            .body(body)
            .expect("reqwest body is valid http body"),
        http_response_builder.body(body).expect("reqwest body is valid http body"),
    );
    Some(response.map_err(|_| {
        warn!(
            "Appservice returned invalid response bytes {}\n{}",
            destination, url
        );
        warn!("Appservice returned invalid response bytes {}\n{}", destination, url);
        Error::BadServerResponse("Server returned bad response.")
    }))
    } else {
@@ -1,11 +1,10 @@
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};
use register::RegistrationKind;
use ruma::{
    api::client::{
        account::{
            change_password, deactivate, get_3pids, get_username_availability, register,
            request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn,
            whoami, ThirdPartyIdRemovalStatus,
            request_3pid_management_token_via_email, request_3pid_management_token_via_msisdn, whoami,
            ThirdPartyIdRemovalStatus,
        },
        error::ErrorKind,
        uiaa::{AuthFlow, AuthType, UiaaInfo},
@@ -15,7 +14,8 @@ use ruma::{
};
use tracing::{info, warn};

use register::RegistrationKind;
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{api::client_server, services, utils, Error, Result, Ruma};

const RANDOM_USER_ID_LENGTH: usize = 10;
@ -28,130 +28,109 @@ const RANDOM_USER_ID_LENGTH: usize = 10;
|
|||
/// - The server name of the user id matches this server
|
||||
/// - No user or appservice on this server already claimed this username
|
||||
///
|
||||
/// Note: This will not reserve the username, so the username might become invalid when trying to register
|
||||
/// Note: This will not reserve the username, so the username might become
|
||||
/// invalid when trying to register
|
||||
pub async fn get_register_available_route(
|
||||
body: Ruma<get_username_availability::v3::Request>,
|
||||
) -> Result<get_username_availability::v3::Response> {
|
||||
// Validate user id
|
||||
let user_id = UserId::parse_with_server_name(
|
||||
body.username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
.filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name())
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||
|
||||
// Check if username is creative enough
|
||||
if services().users.exists(&user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_usernames()
|
||||
.is_match(user_id.localpart())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Username is forbidden.",
|
||||
));
|
||||
if services().globals.forbidden_usernames().is_match(user_id.localpart()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||
}
|
||||
|
||||
// TODO add check for appservice namespaces
|
||||
|
||||
// If no if check is true we have an username that's available to be used.
|
||||
Ok(get_username_availability::v3::Response { available: true })
|
||||
Ok(get_username_availability::v3::Response {
|
||||
available: true,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `POST /_matrix/client/v3/register`
|
||||
///
|
||||
/// Register an account on this homeserver.
|
||||
///
|
||||
/// You can use [`GET /_matrix/client/v3/register/available`](fn.get_register_available_route.html)
|
||||
/// to check if the user id is valid and available.
|
||||
/// You can use [`GET
|
||||
/// /_matrix/client/v3/register/available`](fn.get_register_available_route.
|
||||
/// html) to check if the user id is valid and available.
|
||||
///
|
||||
/// - Only works if registration is enabled
|
||||
/// - If type is guest: ignores all parameters except initial_device_display_name
|
||||
/// - If type is guest: ignores all parameters except
|
||||
/// initial_device_display_name
|
||||
/// - If sender is not appservice: Requires UIAA (but we only use a dummy stage)
|
||||
/// - If type is not guest and no username is given: Always fails after UIAA check
|
||||
/// - If type is not guest and no username is given: Always fails after UIAA
|
||||
/// check
|
||||
/// - Creates a new account and populates it with default account data
|
||||
/// - If `inhibit_login` is false: Creates a device and returns device id and access_token
|
||||
/// - If `inhibit_login` is false: Creates a device and returns device id and
|
||||
/// access_token
|
||||
pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<register::v3::Response> {
|
||||
if !services().globals.allow_registration() && !body.from_appservice {
|
||||
info!("Registration disabled and request not from known appservice, rejecting registration attempt for username {:?}", body.username);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Registration has been disabled.",
|
||||
));
|
||||
info!(
|
||||
"Registration disabled and request not from known appservice, rejecting registration attempt for username \
|
||||
{:?}",
|
||||
body.username
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration has been disabled."));
|
||||
}
|
||||
|
||||
let is_guest = body.kind == RegistrationKind::Guest;
|
||||
|
||||
if is_guest
|
||||
&& (!services().globals.allow_guest_registration()
|
||||
|| (services().globals.allow_registration()
|
||||
&& services().globals.config.registration_token.is_some()))
|
||||
|| (services().globals.allow_registration() && services().globals.config.registration_token.is_some()))
|
||||
{
|
||||
info!("Guest registration disabled / registration enabled with token configured, rejecting guest registration, initial device name: {:?}", body.initial_device_display_name);
|
||||
info!(
|
||||
"Guest registration disabled / registration enabled with token configured, rejecting guest registration, \
|
||||
initial device name: {:?}",
|
||||
body.initial_device_display_name
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::GuestAccessForbidden,
|
||||
"Guest registration is disabled.",
|
||||
));
|
||||
}
|
||||
|
||||
// forbid guests from registering if there is not a real admin user yet. give generic user error.
|
||||
// forbid guests from registering if there is not a real admin user yet. give
|
||||
// generic user error.
|
||||
if is_guest && services().users.count()? < 2 {
|
||||
warn!("Guest account attempted to register before a real admin user has been registered, rejecting registration. Guest's initial device name: {:?}", body.initial_device_display_name);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Registration temporarily disabled.",
|
||||
));
|
||||
warn!(
|
||||
"Guest account attempted to register before a real admin user has been registered, rejecting \
|
||||
registration. Guest's initial device name: {:?}",
|
||||
body.initial_device_display_name
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Registration temporarily disabled."));
|
||||
}
|
||||
|
||||
let user_id = match (&body.username, is_guest) {
|
||||
(Some(username), false) => {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
username.to_lowercase(),
|
||||
services().globals.server_name(),
|
||||
)
|
||||
let proposed_user_id =
|
||||
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
|
||||
.ok()
|
||||
.filter(|user_id| {
|
||||
!user_id.is_historical()
|
||||
&& user_id.server_name() == services().globals.server_name()
|
||||
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
|
||||
})
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidUsername,
|
||||
"Username is invalid.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
|
||||
|
||||
if services().users.exists(&proposed_user_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserInUse,
|
||||
"Desired user ID is already taken.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::UserInUse, "Desired user ID is already taken."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_usernames()
|
||||
.is_match(proposed_user_id.localpart())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Username is forbidden.",
|
||||
));
|
||||
if services().globals.forbidden_usernames().is_match(proposed_user_id.localpart()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Username is forbidden."));
|
||||
}
|
||||
|
||||
proposed_user_id
|
||||
}
|
||||
},
|
||||
_ => loop {
|
||||
let proposed_user_id = UserId::parse_with_server_name(
|
||||
utils::random_string(RANDOM_USER_ID_LENGTH).to_lowercase(),
|
||||
|
@ -196,8 +175,7 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
|||
if !skip_auth {
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||
"".into(),
|
||||
auth,
|
||||
&uiaainfo,
|
||||
|
@ -209,8 +187,7 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
|||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services().uiaa.create(
|
||||
&UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid"),
|
||||
&UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid"),
|
||||
"".into(),
|
||||
&uiaainfo,
|
||||
&json,
|
||||
|
@ -233,15 +210,13 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
|||
// Default to pretty displayname
|
||||
let mut displayname = user_id.localpart().to_owned();
|
||||
|
||||
// If `new_user_displayname_suffix` is set, registration will push whatever content is set to the user's display name with a space before it
|
||||
// If `new_user_displayname_suffix` is set, registration will push whatever
|
||||
// content is set to the user's display name with a space before it
|
||||
if !services().globals.new_user_displayname_suffix().is_empty() {
|
||||
displayname.push_str(&(" ".to_owned() + services().globals.new_user_displayname_suffix()));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&user_id, Some(displayname.clone()))
|
||||
.await?;
|
||||
services().users.set_displayname(&user_id, Some(displayname.clone())).await?;
|
||||
|
||||
// Initial account data
|
||||
services().account_data.update(
|
||||
|
@ -279,41 +254,29 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
|||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
// Create device for this account
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
body.initial_device_display_name.clone(),
|
||||
)?;
|
||||
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||
|
||||
info!("New user \"{}\" registered on this server.", user_id);
|
||||
|
||||
// log in conduit admin channel if a non-guest user registered
|
||||
if !body.from_appservice && !is_guest {
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"New user \"{user_id}\" registered on this server."
|
||||
)));
|
||||
}
|
||||
|
||||
// log in conduit admin channel if a guest registered
|
||||
if !body.from_appservice && is_guest {
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"Guest user \"{user_id}\" with device display name `{:?}` registered on this server.",
|
||||
body.initial_device_display_name
|
||||
)));
|
||||
}
|
||||
|
||||
// If this is the first real user, grant them admin privileges except for guest users
|
||||
// Note: the server user, @conduit:servername, is generated first
|
||||
// If this is the first real user, grant them admin privileges except for guest
|
||||
// users Note: the server user, @conduit:servername, is generated first
|
||||
if services().users.count()? == 2 && !is_guest {
|
||||
services()
|
||||
.admin
|
||||
.make_user_admin(&user_id, displayname)
|
||||
.await?;
|
||||
services().admin.make_user_admin(&user_id, displayname).await?;
|
||||
|
||||
warn!("Granting {} admin privileges as the first user", user_id);
|
||||
}
|
||||
|
@ -333,17 +296,18 @@ pub async fn register_route(body: Ruma<register::v3::Request>) -> Result<registe
|
|||
///
|
||||
/// - Requires UIAA to verify user password
|
||||
/// - Changes the password of the sender user
|
||||
/// - The password hash is calculated using argon2 with 32 character salt, the plain password is
|
||||
/// - The password hash is calculated using argon2 with 32 character salt, the
|
||||
/// plain password is
|
||||
/// not saved
|
||||
///
|
||||
/// If logout_devices is true it does the following for each device except the sender device:
|
||||
/// If logout_devices is true it does the following for each device except the
|
||||
/// sender device:
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn change_password_route(
|
||||
body: Ruma<change_password::v3::Request>,
|
||||
) -> Result<change_password::v3::Response> {
|
||||
pub async fn change_password_route(body: Ruma<change_password::v3::Request>) -> Result<change_password::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -358,27 +322,20 @@ pub async fn change_password_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_password(sender_user, Some(&body.new_password))?;
|
||||
services().users.set_password(sender_user, Some(&body.new_password))?;
|
||||
|
||||
if body.logout_devices {
|
||||
// Logout all devices except the current one
|
||||
|
@ -393,9 +350,7 @@ pub async fn change_password_route(
|
|||
}
|
||||
|
||||
info!("User {} changed their password.", sender_user);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} changed their password."
|
||||
)));
|
||||
|
||||
|
@ -424,13 +379,12 @@ pub async fn whoami_route(body: Ruma<whoami::v3::Request>) -> Result<whoami::v3:
|
|||
///
|
||||
/// - Leaves all rooms and rejects all invitations
|
||||
/// - Invalidates all access tokens
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets all to-device events
|
||||
/// - Triggers device list updates
|
||||
/// - Removes ability to log in again
|
||||
pub async fn deactivate_route(
|
||||
body: Ruma<deactivate::v3::Request>,
|
||||
) -> Result<deactivate::v3::Response> {
|
||||
pub async fn deactivate_route(body: Ruma<deactivate::v3::Request>) -> Result<deactivate::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -445,19 +399,14 @@ pub async fn deactivate_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
|
@ -470,9 +419,7 @@ pub async fn deactivate_route(
|
|||
services().users.deactivate_account(sender_user)?;
|
||||
|
||||
info!("User {} deactivated their account.", sender_user);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
services().admin.send_message(RoomMessageEventContent::notice_plain(format!(
|
||||
"User {sender_user} deactivated their account."
|
||||
)));
|
||||
|
||||
|
@ -486,9 +433,7 @@ pub async fn deactivate_route(
|
|||
/// Get a list of third party identifiers associated with this account.
|
||||
///
|
||||
/// - Currently always returns empty list
|
||||
pub async fn third_party_route(
|
||||
body: Ruma<get_3pids::v3::Request>,
|
||||
) -> Result<get_3pids::v3::Response> {
|
||||
pub async fn third_party_route(body: Ruma<get_3pids::v3::Request>) -> Result<get_3pids::v3::Response> {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_3pids::v3::Response::new(Vec::new()))
|
||||
|
@ -496,9 +441,11 @@ pub async fn third_party_route(
|
|||
|
||||
/// # `POST /_matrix/client/v3/account/3pid/email/requestToken`
|
||||
///
|
||||
/// "This API should be used to request validation tokens when adding an email address to an account"
|
||||
/// "This API should be used to request validation tokens when adding an email
|
||||
/// address to an account"
|
||||
///
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||
/// as a contact option.
|
||||
pub async fn request_3pid_management_token_via_email_route(
|
||||
_body: Ruma<request_3pid_management_token_via_email::v3::Request>,
|
||||
) -> Result<request_3pid_management_token_via_email::v3::Response> {
|
||||
|
@ -510,9 +457,11 @@ pub async fn request_3pid_management_token_via_email_route(
|
|||
|
||||
/// # `POST /_matrix/client/v3/account/3pid/msisdn/requestToken`
|
||||
///
|
||||
/// "This API should be used to request validation tokens when adding an phone number to an account"
|
||||
/// "This API should be used to request validation tokens when adding an phone
|
||||
/// number to an account"
|
||||
///
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier as a contact option.
|
||||
/// - 403 signals that The homeserver does not allow the third party identifier
|
||||
/// as a contact option.
|
||||
pub async fn request_3pid_management_token_via_msisdn_route(
|
||||
_body: Ruma<request_3pid_management_token_via_msisdn::v3::Request>,
|
||||
) -> Result<request_3pid_management_token_via_msisdn::v3::Response> {
|
||||
|
|
|
@ -1,4 +1,3 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use rand::seq::SliceRandom;
|
||||
use regex::Regex;
|
||||
use ruma::{
|
||||
|
@ -13,45 +12,25 @@ use ruma::{
|
|||
OwnedRoomAliasId, OwnedServerName,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
|
||||
///
|
||||
/// Creates a new room alias on this server.
|
||||
pub async fn create_alias_route(
|
||||
body: Ruma<create_alias::v3::Request>,
|
||||
) -> Result<create_alias::v3::Response> {
|
||||
pub async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_room_names()
|
||||
.is_match(body.room_alias.alias())
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Room alias is forbidden.",
|
||||
));
|
||||
if services().globals.forbidden_room_names().is_match(body.room_alias.alias()) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias is forbidden."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&body.room_alias)?
|
||||
.is_some()
|
||||
{
|
||||
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_some() {
|
||||
return Err(Error::Conflict("Alias already exists."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.set_alias(&body.room_alias, &body.room_id)
|
||||
.is_err()
|
||||
{
|
||||
if services().rooms.alias.set_alias(&body.room_alias, &body.room_id).is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
|
@ -67,34 +46,16 @@ pub async fn create_alias_route(
|
|||
///
|
||||
/// - TODO: additional access control checks
|
||||
/// - TODO: Update canonical alias event
|
||||
pub async fn delete_alias_route(
|
||||
body: Ruma<delete_alias::v3::Request>,
|
||||
) -> Result<delete_alias::v3::Response> {
|
||||
pub async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
|
||||
if body.room_alias.server_name() != services().globals.server_name() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Alias is from another server.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&body.room_alias)?
|
||||
.is_none()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
if services().rooms.alias.resolve_local_alias(&body.room_alias)?.is_none() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||
}
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.remove_alias(&body.room_alias)
|
||||
.is_err()
|
||||
{
|
||||
if services().rooms.alias.remove_alias(&body.room_alias).is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid room alias. Alias must be in the form of '#localpart:server_name'",
|
||||
|
@ -109,15 +70,11 @@ pub async fn delete_alias_route(
|
|||
/// # `GET /_matrix/client/v3/directory/room/{roomAlias}`
|
||||
///
|
||||
/// Resolve an alias locally or over federation.
|
||||
pub async fn get_alias_route(
|
||||
body: Ruma<get_alias::v3::Request>,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
pub async fn get_alias_route(body: Ruma<get_alias::v3::Request>) -> Result<get_alias::v3::Response> {
|
||||
get_alias_helper(body.body.room_alias).await
|
||||
}
|
||||
|
||||
pub(crate) async fn get_alias_helper(
|
||||
room_alias: OwnedRoomAliasId,
|
||||
) -> Result<get_alias::v3::Response> {
|
||||
pub(crate) async fn get_alias_helper(room_alias: OwnedRoomAliasId) -> Result<get_alias::v3::Response> {
|
||||
if room_alias.server_name() != services().globals.server_name() {
|
||||
let response = services()
|
||||
.sending
|
||||
|
@ -134,20 +91,13 @@ pub(crate) async fn get_alias_helper(
|
|||
let mut servers = response.servers;
|
||||
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_servers(&room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) = servers
|
||||
.clone()
|
||||
.into_iter()
|
||||
.position(|server| server == services().globals.server_name())
|
||||
if let Some(server_index) =
|
||||
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
|
@ -174,9 +124,7 @@ pub(crate) async fn get_alias_helper(
|
|||
.filter_map(|alias| Regex::new(alias.regex.as_str()).ok())
|
||||
.collect::<Vec<_>>();
|
||||
|
||||
if aliases
|
||||
.iter()
|
||||
.any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||
if aliases.iter().any(|aliases| aliases.is_match(room_alias.as_str()))
|
||||
&& if let Some(opt_result) = services()
|
||||
.sending
|
||||
.send_appservice_request(
|
||||
|
@ -190,50 +138,35 @@ pub(crate) async fn get_alias_helper(
|
|||
opt_result.is_ok()
|
||||
} else {
|
||||
false
|
||||
}
|
||||
{
|
||||
} {
|
||||
room_id = Some(
|
||||
services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&room_alias)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_config("Appservice lied to us. Room does not exist.")
|
||||
})?,
|
||||
.ok_or_else(|| Error::bad_config("Appservice lied to us. Room does not exist."))?,
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let room_id = match room_id {
|
||||
Some(room_id) => room_id,
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Room with alias not found.",
|
||||
))
|
||||
}
|
||||
None => return Err(Error::BadRequest(ErrorKind::NotFound, "Room with alias not found.")),
|
||||
};
|
||||
|
||||
let mut servers: Vec<OwnedServerName> = Vec::new();
|
||||
|
||||
// find active servers in room state cache to suggest
|
||||
for extra_servers in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.room_servers(&room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for extra_servers in services().rooms.state_cache.room_servers(&room_id).filter_map(std::result::Result::ok) {
|
||||
servers.push(extra_servers);
|
||||
}
|
||||
|
||||
// insert our server as the very first choice if in list
|
||||
if let Some(server_index) = servers
|
||||
.clone()
|
||||
.into_iter()
|
||||
.position(|server| server == services().globals.server_name())
|
||||
if let Some(server_index) =
|
||||
servers.clone().into_iter().position(|server| server == services().globals.server_name())
|
||||
{
|
||||
servers.remove(server_index);
|
||||
servers.insert(0, services().globals.server_name().to_owned());
|
||||
|
|
|
@ -1,15 +1,15 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::api::client::{
|
||||
backup::{
|
||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session,
|
||||
create_backup_version, delete_backup_keys, delete_backup_keys_for_room,
|
||||
delete_backup_keys_for_session, delete_backup_version, get_backup_info, get_backup_keys,
|
||||
get_backup_keys_for_room, get_backup_keys_for_session, get_latest_backup_info,
|
||||
update_backup_version,
|
||||
add_backup_keys, add_backup_keys_for_room, add_backup_keys_for_session, create_backup_version,
|
||||
delete_backup_keys, delete_backup_keys_for_room, delete_backup_keys_for_session, delete_backup_version,
|
||||
get_backup_info, get_backup_keys, get_backup_keys_for_room, get_backup_keys_for_session,
|
||||
get_latest_backup_info, update_backup_version,
|
||||
},
|
||||
error::ErrorKind,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/room_keys/version`
|
||||
///
|
||||
/// Creates a new backup.
|
||||
|
@ -17,23 +17,22 @@ pub async fn create_backup_version_route(
|
|||
body: Ruma<create_backup_version::v3::Request>,
|
||||
) -> Result<create_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let version = services()
|
||||
.key_backups
|
||||
.create_backup(sender_user, &body.algorithm)?;
|
||||
let version = services().key_backups.create_backup(sender_user, &body.algorithm)?;
|
||||
|
||||
Ok(create_backup_version::v3::Response { version })
|
||||
Ok(create_backup_version::v3::Response {
|
||||
version,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/room_keys/version/{version}`
|
||||
///
|
||||
/// Update information about an existing backup. Only `auth_data` can be modified.
|
||||
/// Update information about an existing backup. Only `auth_data` can be
|
||||
/// modified.
|
||||
pub async fn update_backup_version_route(
|
||||
body: Ruma<update_backup_version::v3::Request>,
|
||||
) -> Result<update_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
services()
|
||||
.key_backups
|
||||
.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||
services().key_backups.update_backup(sender_user, &body.version, &body.algorithm)?;
|
||||
|
||||
Ok(update_backup_version::v3::Response {})
|
||||
}
|
||||
|
@ -49,10 +48,7 @@ pub async fn get_latest_backup_info_route(
|
|||
let (version, algorithm) = services()
|
||||
.key_backups
|
||||
.get_latest_backup(sender_user)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||
|
||||
Ok(get_latest_backup_info::v3::Response {
|
||||
algorithm,
|
||||
|
@ -65,27 +61,17 @@ pub async fn get_latest_backup_info_route(
|
|||
/// # `GET /_matrix/client/r0/room_keys/version`
|
||||
///
|
||||
/// Get information about an existing backup.
|
||||
pub async fn get_backup_info_route(
|
||||
body: Ruma<get_backup_info::v3::Request>,
|
||||
) -> Result<get_backup_info::v3::Response> {
|
||||
pub async fn get_backup_info_route(body: Ruma<get_backup_info::v3::Request>) -> Result<get_backup_info::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let algorithm = services()
|
||||
.key_backups
|
||||
.get_backup(sender_user, &body.version)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Key backup does not exist.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Key backup does not exist."))?;
|
||||
|
||||
Ok(get_backup_info::v3::Response {
|
||||
algorithm,
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
version: body.version.clone(),
|
||||
})
|
||||
}
|
||||
|
@ -94,15 +80,14 @@ pub async fn get_backup_info_route(
|
|||
///
|
||||
/// Delete an existing key backup.
|
||||
///
|
||||
/// - Deletes both information about the backup, as well as all key data related to the backup
|
||||
/// - Deletes both information about the backup, as well as all key data related
|
||||
/// to the backup
|
||||
pub async fn delete_backup_version_route(
|
||||
body: Ruma<delete_backup_version::v3::Request>,
|
||||
) -> Result<delete_backup_version::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_backup(sender_user, &body.version)?;
|
||||
services().key_backups.delete_backup(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_version::v3::Response {})
|
||||
}
|
||||
|
@ -111,20 +96,14 @@ pub async fn delete_backup_version_route(
|
|||
///
|
||||
/// Add the received backup keys to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_route(
|
||||
body: Ruma<add_backup_keys::v3::Request>,
|
||||
) -> Result<add_backup_keys::v3::Response> {
|
||||
pub async fn add_backup_keys_route(body: Ruma<add_backup_keys::v3::Request>) -> Result<add_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
|
@ -133,24 +112,13 @@ pub async fn add_backup_keys_route(
|
|||
|
||||
for (room_id, room) in &body.rooms {
|
||||
for (session_id, key_data) in &room.sessions {
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
)?;
|
||||
services().key_backups.add_key(sender_user, &body.version, room_id, session_id, key_data)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(add_backup_keys::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -158,7 +126,8 @@ pub async fn add_backup_keys_route(
|
|||
///
|
||||
/// Add the received backup keys to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_room_route(
|
||||
|
@ -166,12 +135,7 @@ pub async fn add_backup_keys_for_room_route(
|
|||
) -> Result<add_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
|
@ -179,23 +143,12 @@ pub async fn add_backup_keys_for_room_route(
|
|||
}
|
||||
|
||||
for (session_id, key_data) in &body.sessions {
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
session_id,
|
||||
key_data,
|
||||
)?;
|
||||
services().key_backups.add_key(sender_user, &body.version, &body.room_id, session_id, key_data)?;
|
||||
}
|
||||
|
||||
Ok(add_backup_keys_for_room::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -203,7 +156,8 @@ pub async fn add_backup_keys_for_room_route(
|
|||
///
|
||||
/// Add the received backup key to the database.
|
||||
///
|
||||
/// - Only manipulating the most recently created version of the backup is allowed
|
||||
/// - Only manipulating the most recently created version of the backup is
|
||||
/// allowed
|
||||
/// - Adds the keys to the backup
|
||||
/// - Returns the new number of keys in this backup and the etag
|
||||
pub async fn add_backup_keys_for_session_route(
|
||||
|
@ -211,48 +165,32 @@ pub async fn add_backup_keys_for_session_route(
|
|||
) -> Result<add_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if Some(&body.version)
|
||||
!= services()
|
||||
.key_backups
|
||||
.get_latest_backup_version(sender_user)?
|
||||
.as_ref()
|
||||
{
|
||||
if Some(&body.version) != services().key_backups.get_latest_backup_version(sender_user)?.as_ref() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"You may only manipulate the most recently created version of the backup.",
|
||||
));
|
||||
}
|
||||
|
||||
services().key_backups.add_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
&body.session_data,
|
||||
)?;
|
||||
services().key_backups.add_key(sender_user, &body.version, &body.room_id, &body.session_id, &body.session_data)?;
|
||||
|
||||
Ok(add_backup_keys_for_session::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys`
|
||||
///
|
||||
/// Retrieves all keys from the backup.
|
||||
pub async fn get_backup_keys_route(
|
||||
body: Ruma<get_backup_keys::v3::Request>,
|
||||
) -> Result<get_backup_keys::v3::Response> {
|
||||
pub async fn get_backup_keys_route(body: Ruma<get_backup_keys::v3::Request>) -> Result<get_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let rooms = services().key_backups.get_all(sender_user, &body.version)?;
|
||||
|
||||
Ok(get_backup_keys::v3::Response { rooms })
|
||||
Ok(get_backup_keys::v3::Response {
|
||||
rooms,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}`
|
||||
|
@ -263,11 +201,11 @@ pub async fn get_backup_keys_for_room_route(
|
|||
) -> Result<get_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let sessions = services()
|
||||
.key_backups
|
||||
.get_room(sender_user, &body.version, &body.room_id)?;
|
||||
let sessions = services().key_backups.get_room(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
Ok(get_backup_keys_for_room::v3::Response { sessions })
|
||||
Ok(get_backup_keys_for_room::v3::Response {
|
||||
sessions,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/room_keys/keys/{roomId}/{sessionId}`
|
||||
|
@ -278,15 +216,14 @@ pub async fn get_backup_keys_for_session_route(
|
|||
) -> Result<get_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let key_data = services()
|
||||
.key_backups
|
||||
.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Backup key not found for this user's session.",
|
||||
))?;
|
||||
let key_data =
|
||||
services().key_backups.get_session(sender_user, &body.version, &body.room_id, &body.session_id)?.ok_or(
|
||||
Error::BadRequest(ErrorKind::NotFound, "Backup key not found for this user's session."),
|
||||
)?;
|
||||
|
||||
Ok(get_backup_keys_for_session::v3::Response { key_data })
|
||||
Ok(get_backup_keys_for_session::v3::Response {
|
||||
key_data,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `DELETE /_matrix/client/r0/room_keys/keys`
|
||||
|
@ -297,18 +234,11 @@ pub async fn delete_backup_keys_route(
|
|||
) -> Result<delete_backup_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_all_keys(sender_user, &body.version)?;
|
||||
services().key_backups.delete_all_keys(sender_user, &body.version)?;
|
||||
|
||||
Ok(delete_backup_keys::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -320,18 +250,11 @@ pub async fn delete_backup_keys_for_room_route(
|
|||
) -> Result<delete_backup_keys_for_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.key_backups
|
||||
.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||
services().key_backups.delete_room_keys(sender_user, &body.version, &body.room_id)?;
|
||||
|
||||
Ok(delete_backup_keys_for_room::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -343,20 +266,10 @@ pub async fn delete_backup_keys_for_session_route(
|
|||
) -> Result<delete_backup_keys_for_session::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services().key_backups.delete_room_key(
|
||||
sender_user,
|
||||
&body.version,
|
||||
&body.room_id,
|
||||
&body.session_id,
|
||||
)?;
|
||||
services().key_backups.delete_room_key(sender_user, &body.version, &body.room_id, &body.session_id)?;
|
||||
|
||||
Ok(delete_backup_keys_for_session::v3::Response {
|
||||
count: (services()
|
||||
.key_backups
|
||||
.count_keys(sender_user, &body.version)? as u32)
|
||||
.into(),
|
||||
etag: services()
|
||||
.key_backups
|
||||
.get_etag(sender_user, &body.version)?,
|
||||
count: (services().key_backups.count_keys(sender_user, &body.version)? as u32).into(),
|
||||
etag: services().key_backups.get_etag(sender_user, &body.version)?,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,12 +1,15 @@
|
|||
use crate::{services, Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::api::client::discovery::get_capabilities::{
|
||||
self, Capabilities, RoomVersionStability, RoomVersionsCapability,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/capabilities`
|
||||
///
|
||||
/// Get information on the supported feature set and other relevent capabilities of this server.
|
||||
/// Get information on the supported feature set and other relevent capabilities
|
||||
/// of this server.
|
||||
pub async fn get_capabilities_route(
|
||||
_body: Ruma<get_capabilities::v3::Request>,
|
||||
) -> Result<get_capabilities::v3::Response> {
|
||||
|
@ -24,5 +27,7 @@ pub async fn get_capabilities_route(
|
|||
available,
|
||||
};
|
||||
|
||||
Ok(get_capabilities::v3::Response { capabilities })
|
||||
Ok(get_capabilities::v3::Response {
|
||||
capabilities,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,10 +1,6 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
config::{
|
||||
get_global_account_data, get_room_account_data, set_global_account_data,
|
||||
set_room_account_data,
|
||||
},
|
||||
config::{get_global_account_data, get_room_account_data, set_global_account_data, set_room_account_data},
|
||||
error::ErrorKind,
|
||||
},
|
||||
events::{AnyGlobalAccountDataEventContent, AnyRoomAccountDataEventContent},
|
||||
|
@ -13,6 +9,8 @@ use ruma::{
|
|||
use serde::Deserialize;
|
||||
use serde_json::{json, value::RawValue as RawJsonValue};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/account_data/{type}`
|
||||
///
|
||||
/// Sets some account data for the sender user.
|
||||
|
@ -82,7 +80,9 @@ pub async fn get_global_account_data_route(
|
|||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_global_account_data::v3::Response { account_data })
|
||||
Ok(get_global_account_data::v3::Response {
|
||||
account_data,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/user/{userId}/rooms/{roomId}/account_data/{type}`
|
||||
|
@ -102,7 +102,9 @@ pub async fn get_room_account_data_route(
|
|||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
Ok(get_room_account_data::v3::Response { account_data })
|
||||
Ok(get_room_account_data::v3::Response {
|
||||
account_data,
|
||||
})
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
|
|
|
@ -1,20 +1,21 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use std::collections::HashSet;
|
||||
|
||||
use ruma::{
|
||||
api::client::{context::get_context, error::ErrorKind, filter::LazyLoadOptions},
|
||||
events::StateEventType,
|
||||
};
|
||||
use std::collections::HashSet;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/context`
|
||||
///
|
||||
/// Allows loading room history around an event.
|
||||
///
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events if the user was
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||
/// if the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_context_route(
|
||||
body: Ruma<get_context::v3::Request>,
|
||||
) -> Result<get_context::v3::Response> {
|
||||
pub async fn get_context_route(body: Ruma<get_context::v3::Request>) -> Result<get_context::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -31,28 +32,17 @@ pub async fn get_context_route(
|
|||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Base event id not found.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event id not found."))?;
|
||||
|
||||
let base_event =
|
||||
services()
|
||||
let base_event = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Base event not found.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Base event not found."))?;
|
||||
|
||||
let room_id = base_event.room_id.clone();
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_event(sender_user, &room_id, &body.event_id)?
|
||||
{
|
||||
if !services().rooms.state_accessor.user_can_see_event(sender_user, &room_id, &body.event_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this event.",
|
||||
|
@ -101,15 +91,10 @@ pub async fn get_context_route(
|
|||
}
|
||||
}
|
||||
|
||||
let start_token = events_before
|
||||
.last()
|
||||
.map(|(count, _)| count.stringify())
|
||||
.unwrap_or_else(|| base_token.stringify());
|
||||
let start_token =
|
||||
events_before.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||
|
||||
let events_before: Vec<_> = events_before
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
let events_after: Vec<_> = services()
|
||||
.rooms
|
||||
|
@ -138,42 +123,25 @@ pub async fn get_context_route(
|
|||
}
|
||||
}
|
||||
|
||||
let shortstatehash = match services().rooms.state_accessor.pdu_shortstatehash(
|
||||
events_after
|
||||
.last()
|
||||
.map_or(&*body.event_id, |(_, e)| &*e.event_id),
|
||||
)? {
|
||||
Some(s) => s,
|
||||
None => services()
|
||||
.rooms
|
||||
.state
|
||||
.get_room_shortstatehash(&room_id)?
|
||||
.expect("All rooms have state"),
|
||||
};
|
||||
|
||||
let state_ids = services()
|
||||
let shortstatehash = match services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.state_full_ids(shortstatehash)
|
||||
.await?;
|
||||
.pdu_shortstatehash(events_after.last().map_or(&*body.event_id, |(_, e)| &*e.event_id))?
|
||||
{
|
||||
Some(s) => s,
|
||||
None => services().rooms.state.get_room_shortstatehash(&room_id)?.expect("All rooms have state"),
|
||||
};
|
||||
|
||||
let end_token = events_after
|
||||
.last()
|
||||
.map(|(count, _)| count.stringify())
|
||||
.unwrap_or_else(|| base_token.stringify());
|
||||
let state_ids = services().rooms.state_accessor.state_full_ids(shortstatehash).await?;
|
||||
|
||||
let events_after: Vec<_> = events_after
|
||||
.into_iter()
|
||||
.map(|(_, pdu)| pdu.to_room_event())
|
||||
.collect();
|
||||
let end_token = events_after.last().map(|(count, _)| count.stringify()).unwrap_or_else(|| base_token.stringify());
|
||||
|
||||
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();
|
||||
|
||||
let mut state = Vec::new();
|
||||
|
||||
for (shortstatekey, id) in state_ids {
|
||||
let (event_type, state_key) = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_statekey_from_short(shortstatekey)?;
|
||||
let (event_type, state_key) = services().rooms.short.get_statekey_from_short(shortstatekey)?;
|
||||
|
||||
if event_type != StateEventType::RoomMember {
|
||||
let pdu = match services().rooms.timeline.get_pdu(&id)? {
|
||||
|
@ -181,7 +149,7 @@ pub async fn get_context_route(
|
|||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
} else if !lazy_load_enabled || lazy_loaded.contains(&state_key) {
|
||||
|
@ -190,7 +158,7 @@ pub async fn get_context_route(
|
|||
None => {
|
||||
error!("Pdu in state not found: {}", id);
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
state.push(pdu.to_state_event());
|
||||
}
|
||||
@ -1,4 +1,3 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{
device::{self, delete_device, delete_devices, get_device, get_devices, update_device},
error::ErrorKind,
@ -6,13 +5,12 @@ use ruma::api::client::{
};

use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};

/// # `GET /_matrix/client/r0/devices`
///
/// Get metadata on all devices of the sender user.
pub async fn get_devices_route(
body: Ruma<get_devices::v3::Request>,
) -> Result<get_devices::v3::Response> {
pub async fn get_devices_route(body: Ruma<get_devices::v3::Request>) -> Result<get_devices::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");

let devices: Vec<device::Device> = services()
@ -21,15 +19,15 @@ pub async fn get_devices_route(
|
|||
.filter_map(std::result::Result::ok) // Filter out buggy devices
|
||||
.collect();
|
||||
|
||||
Ok(get_devices::v3::Response { devices })
|
||||
Ok(get_devices::v3::Response {
|
||||
devices,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/devices/{deviceId}`
|
||||
///
|
||||
/// Get metadata on a single device of the sender user.
|
||||
pub async fn get_device_route(
|
||||
body: Ruma<get_device::v3::Request>,
|
||||
) -> Result<get_device::v3::Response> {
|
||||
pub async fn get_device_route(body: Ruma<get_device::v3::Request>) -> Result<get_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let device = services()
|
||||
|
@ -37,15 +35,15 @@ pub async fn get_device_route(
|
|||
.get_device_metadata(sender_user, &body.body.device_id)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Device not found."))?;
|
||||
|
||||
Ok(get_device::v3::Response { device })
|
||||
Ok(get_device::v3::Response {
|
||||
device,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/devices/{deviceId}`
|
||||
///
|
||||
/// Updates the metadata on a given device of the sender user.
|
||||
pub async fn update_device_route(
|
||||
body: Ruma<update_device::v3::Request>,
|
||||
) -> Result<update_device::v3::Response> {
|
||||
pub async fn update_device_route(body: Ruma<update_device::v3::Request>) -> Result<update_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut device = services()
|
||||
|
@ -55,9 +53,7 @@ pub async fn update_device_route(
|
|||
|
||||
device.display_name = body.display_name.clone();
|
||||
|
||||
services()
|
||||
.users
|
||||
.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||
services().users.update_device_metadata(sender_user, &body.device_id, &device)?;
|
||||
|
||||
Ok(update_device::v3::Response {})
|
||||
}
|
||||
|
@ -68,12 +64,11 @@ pub async fn update_device_route(
|
|||
///
|
||||
/// - Requires UIAA to verify user password
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_device_route(
|
||||
body: Ruma<delete_device::v3::Request>,
|
||||
) -> Result<delete_device::v3::Response> {
|
||||
pub async fn delete_device_route(body: Ruma<delete_device::v3::Request>) -> Result<delete_device::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -89,27 +84,20 @@ pub async fn delete_device_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.remove_device(sender_user, &body.device_id)?;
|
||||
services().users.remove_device(sender_user, &body.device_id)?;
|
||||
|
||||
Ok(delete_device::v3::Response {})
|
||||
}
|
||||
|
@ -122,12 +110,11 @@ pub async fn delete_device_route(
|
|||
///
|
||||
/// For each device:
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn delete_devices_route(
|
||||
body: Ruma<delete_devices::v3::Request>,
|
||||
) -> Result<delete_devices::v3::Response> {
|
||||
pub async fn delete_devices_route(body: Ruma<delete_devices::v3::Request>) -> Result<delete_devices::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -143,19 +130,14 @@ pub async fn delete_devices_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
|
@ -1,11 +1,7 @@
use crate::{services, Error, Result, Ruma};
use ruma::{
api::{
client::{
directory::{
get_public_rooms, get_public_rooms_filtered, get_room_visibility,
set_room_visibility,
},
directory::{get_public_rooms, get_public_rooms_filtered, get_room_visibility, set_room_visibility},
error::ErrorKind,
room,
},
@ -28,6 +24,8 @@ use ruma::{
};
use tracing::{error, info, warn};

use crate::{services, Error, Result, Ruma};

/// # `POST /_matrix/client/v3/publicRooms`
///
/// Lists the public rooms on this server.
@ -36,11 +34,7 @@ use tracing::{error, info, warn};
|
|||
pub async fn get_public_rooms_filtered_route(
|
||||
body: Ruma<get_public_rooms_filtered::v3::Request>,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
|
||||
|
@ -62,11 +56,7 @@ pub async fn get_public_rooms_filtered_route(
|
|||
pub async fn get_public_rooms_route(
|
||||
body: Ruma<get_public_rooms::v3::Request>,
|
||||
) -> Result<get_public_rooms::v3::Response> {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let _sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
}
|
||||
|
||||
|
@ -106,14 +96,14 @@ pub async fn set_room_visibility_route(
|
|||
room::Visibility::Public => {
|
||||
services().rooms.directory.set_public(&body.room_id)?;
|
||||
info!("{} made {} public", sender_user, body.room_id);
|
||||
}
|
||||
},
|
||||
room::Visibility::Private => services().rooms.directory.set_not_public(&body.room_id)?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Room visibility type is not supported.",
|
||||
));
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
Ok(set_room_visibility::v3::Response {})
|
||||
|
@ -140,15 +130,9 @@ pub async fn get_room_visibility_route(
|
|||
}
|
||||
|
||||
pub(crate) async fn get_public_rooms_filtered_helper(
|
||||
server: Option<&ServerName>,
|
||||
limit: Option<UInt>,
|
||||
since: Option<&str>,
|
||||
filter: &Filter,
|
||||
_network: &RoomNetwork,
|
||||
server: Option<&ServerName>, limit: Option<UInt>, since: Option<&str>, filter: &Filter, _network: &RoomNetwork,
|
||||
) -> Result<get_public_rooms_filtered::v3::Response> {
|
||||
if let Some(other_server) =
|
||||
server.filter(|server| *server != services().globals.server_name().as_str())
|
||||
{
|
||||
if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) {
|
||||
let response = services()
|
||||
.sending
|
||||
.send_federation_request(
|
||||
|
@ -181,12 +165,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
let backwards = match characters.next() {
|
||||
Some('n') => false,
|
||||
Some('p') => true,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid `since` token",
|
||||
))
|
||||
}
|
||||
_ => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid `since` token")),
|
||||
};
|
||||
|
||||
num_since = characters
|
||||
|
@ -214,9 +193,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.map_or(Ok(None), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomCanonicalAliasEventContent| c.alias)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid canonical alias event in database.")
|
||||
})
|
||||
.map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
|
||||
})?,
|
||||
name: services().rooms.state_accessor.get_name(&room_id)?,
|
||||
num_joined_members: services()
|
||||
|
@ -251,11 +228,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.map(|c: RoomHistoryVisibilityEventContent| {
|
||||
c.history_visibility == HistoryVisibility::WorldReadable
|
||||
})
|
||||
.map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid room history visibility event in database.",
|
||||
)
|
||||
})
|
||||
.map_err(|_| Error::bad_database("Invalid room history visibility event in database."))
|
||||
})?,
|
||||
guest_can_join: services()
|
||||
.rooms
|
||||
|
@ -263,12 +236,8 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.room_state_get(&room_id, &StateEventType::RoomGuestAccess, "")?
|
||||
.map_or(Ok(false), |s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomGuestAccessEventContent| {
|
||||
c.guest_access == GuestAccess::CanJoin
|
||||
})
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room guest access event in database.")
|
||||
})
|
||||
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
|
||||
.map_err(|_| Error::bad_database("Invalid room guest access event in database."))
|
||||
})?,
|
||||
avatar_url: services()
|
||||
.rooms
|
||||
|
@ -277,9 +246,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.map(|s| {
|
||||
serde_json::from_str(s.content.get())
|
||||
.map(|c: RoomAvatarEventContent| c.url)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room avatar event in database.")
|
||||
})
|
||||
.map_err(|_| Error::bad_database("Invalid room avatar event in database."))
|
||||
})
|
||||
.transpose()?
|
||||
// url is now an Option<String> so we must flatten
|
||||
|
@ -308,12 +275,10 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
|
||||
.map(|s| {
|
||||
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(
|
||||
|e| {
|
||||
serde_json::from_str::<RoomCreateEventContent>(s.content.get()).map_err(|e| {
|
||||
error!("Invalid room create event in database: {}", e);
|
||||
Error::BadDatabase("Invalid room create event in database.")
|
||||
},
|
||||
)
|
||||
})
|
||||
})
|
||||
.transpose()?
|
||||
.and_then(|e| e.room_type),
|
||||
|
@ -323,11 +288,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
})
|
||||
.filter_map(|r: Result<_>| r.ok()) // Filter out buggy rooms
|
||||
.filter(|chunk| {
|
||||
if let Some(query) = filter
|
||||
.generic_search_term
|
||||
.as_ref()
|
||||
.map(|q| q.to_lowercase())
|
||||
{
|
||||
if let Some(query) = filter.generic_search_term.as_ref().map(|q| q.to_lowercase()) {
|
||||
if let Some(name) = &chunk.name {
|
||||
if name.as_str().to_lowercase().contains(&query) {
|
||||
return true;
|
||||
|
@ -359,11 +320,7 @@ pub(crate) async fn get_public_rooms_filtered_helper(
|
|||
|
||||
let total_room_count_estimate = (all_rooms.len() as u32).into();
|
||||
|
||||
let chunk: Vec<_> = all_rooms
|
||||
.into_iter()
|
||||
.skip(num_since as usize)
|
||||
.take(limit as usize)
|
||||
.collect();
|
||||
let chunk: Vec<_> = all_rooms.into_iter().skip(num_since as usize).take(limit as usize).collect();
|
||||
|
||||
let prev_batch = if num_since == 0 {
|
||||
None
|
||||
@ -1,17 +1,16 @@
use crate::{services, Error, Result, Ruma};
use ruma::api::client::{
error::ErrorKind,
filter::{create_filter, get_filter},
};

use crate::{services, Error, Result, Ruma};

/// # `GET /_matrix/client/r0/user/{userId}/filter/{filterId}`
///
/// Loads a filter that was previously created.
///
/// - A user can only access their own filters
pub async fn get_filter_route(
body: Ruma<get_filter::v3::Request>,
) -> Result<get_filter::v3::Response> {
pub async fn get_filter_route(body: Ruma<get_filter::v3::Request>) -> Result<get_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let filter = match services().users.get_filter(sender_user, &body.filter_id)? {
Some(filter) => filter,
@ -24,9 +23,7 @@ pub async fn get_filter_route(
/// # `PUT /_matrix/client/r0/user/{userId}/filter`
///
/// Creates a new filter to be used by other endpoints.
pub async fn create_filter_route(
body: Ruma<create_filter::v3::Request>,
) -> Result<create_filter::v3::Response> {
pub async fn create_filter_route(body: Ruma<create_filter::v3::Request>) -> Result<create_filter::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
Ok(create_filter::v3::Response::new(
services().users.create_filter(sender_user, &body.filter)?,
@ -1,14 +1,14 @@
|
|||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use std::{
|
||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
|
||||
use futures_util::{stream::FuturesUnordered, StreamExt};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
keys::{
|
||||
claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures,
|
||||
upload_signing_keys,
|
||||
},
|
||||
keys::{claim_keys, get_key_changes, get_keys, upload_keys, upload_signatures, upload_signing_keys},
|
||||
uiaa::{AuthFlow, AuthType, UiaaInfo},
|
||||
},
|
||||
federation,
|
||||
|
@ -17,48 +17,36 @@ use ruma::{
|
|||
DeviceKeyAlgorithm, OwnedDeviceId, OwnedUserId, UserId,
|
||||
};
|
||||
use serde_json::json;
|
||||
use std::{
|
||||
collections::{hash_map, BTreeMap, HashMap, HashSet},
|
||||
time::{Duration, Instant},
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use super::SESSION_ID_LENGTH;
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/keys/upload`
|
||||
///
|
||||
/// Publish end-to-end encryption keys for the sender device.
|
||||
///
|
||||
/// - Adds one time keys
|
||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with existing keys?)
|
||||
pub async fn upload_keys_route(
|
||||
body: Ruma<upload_keys::v3::Request>,
|
||||
) -> Result<upload_keys::v3::Response> {
|
||||
/// - If there are no device keys yet: Adds device keys (TODO: merge with
|
||||
/// existing keys?)
|
||||
pub async fn upload_keys_route(body: Ruma<upload_keys::v3::Request>) -> Result<upload_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_ref().expect("user is authenticated");
|
||||
|
||||
for (key_key, key_value) in &body.one_time_keys {
|
||||
services()
|
||||
.users
|
||||
.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||
services().users.add_one_time_key(sender_user, sender_device, key_key, key_value)?;
|
||||
}
|
||||
|
||||
if let Some(device_keys) = &body.device_keys {
|
||||
// TODO: merge this and the existing event?
|
||||
// This check is needed to assure that signatures are kept
|
||||
if services()
|
||||
.users
|
||||
.get_device_keys(sender_user, sender_device)?
|
||||
.is_none()
|
||||
{
|
||||
services()
|
||||
.users
|
||||
.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||
if services().users.get_device_keys(sender_user, sender_device)?.is_none() {
|
||||
services().users.add_device_keys(sender_user, sender_device, device_keys)?;
|
||||
}
|
||||
}
|
||||
|
||||
Ok(upload_keys::v3::Response {
|
||||
one_time_key_counts: services()
|
||||
.users
|
||||
.count_one_time_keys(sender_user, sender_device)?,
|
||||
one_time_key_counts: services().users.count_one_time_keys(sender_user, sender_device)?,
|
||||
})
|
||||
}
|
||||
|
||||
|
@ -68,7 +56,8 @@ pub async fn upload_keys_route(
|
|||
///
|
||||
/// - Always fetches users from other servers over federation
|
||||
/// - Gets master keys, self-signing keys, user signing keys and device keys.
|
||||
/// - The master and self-signing keys contain signatures that the user is allowed to see
|
||||
/// - The master and self-signing keys contain signatures that the user is
|
||||
/// allowed to see
|
||||
pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_keys::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
|
@ -86,9 +75,7 @@ pub async fn get_keys_route(body: Ruma<get_keys::v3::Request>) -> Result<get_key
|
|||
/// # `POST /_matrix/client/r0/keys/claim`
|
||||
///
|
||||
/// Claims one-time keys
|
||||
pub async fn claim_keys_route(
|
||||
body: Ruma<claim_keys::v3::Request>,
|
||||
) -> Result<claim_keys::v3::Response> {
|
||||
pub async fn claim_keys_route(body: Ruma<claim_keys::v3::Request>) -> Result<claim_keys::v3::Response> {
|
||||
let response = claim_keys_helper(&body.one_time_keys).await?;
|
||||
|
||||
Ok(response)
|
||||
|
@ -117,19 +104,14 @@ pub async fn upload_signing_keys_route(
|
|||
};
|
||||
|
||||
if let Some(auth) = &body.auth {
|
||||
let (worked, uiaainfo) =
|
||||
services()
|
||||
.uiaa
|
||||
.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
let (worked, uiaainfo) = services().uiaa.try_auth(sender_user, sender_device, auth, &uiaainfo)?;
|
||||
if !worked {
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
}
|
||||
// Success!
|
||||
} else if let Some(json) = body.json_body {
|
||||
uiaainfo.session = Some(utils::random_string(SESSION_ID_LENGTH));
|
||||
services()
|
||||
.uiaa
|
||||
.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
services().uiaa.create(sender_user, sender_device, &uiaainfo, &json)?;
|
||||
return Err(Error::Uiaa(uiaainfo));
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::NotJson, "Not json."));
|
||||
|
@ -163,20 +145,11 @@ pub async fn upload_signatures_route(
|
|||
|
||||
for signature in key
|
||||
.get("signatures")
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Missing signatures field.",
|
||||
))?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Missing signatures field."))?
|
||||
.get(sender_user.to_string())
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid user in signatures field.",
|
||||
))?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid user in signatures field."))?
|
||||
.as_object()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid signature.",
|
||||
))?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature."))?
|
||||
.clone()
|
||||
.into_iter()
|
||||
{
|
||||
|
@ -186,15 +159,10 @@ pub async fn upload_signatures_route(
|
|||
signature
|
||||
.1
|
||||
.as_str()
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid signature value.",
|
||||
))?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Invalid signature value."))?
|
||||
.to_owned(),
|
||||
);
|
||||
services()
|
||||
.users
|
||||
.sign_key(user_id, key_id, signature, sender_user)?;
|
||||
services().users.sign_key(user_id, key_id, signature, sender_user)?;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -206,12 +174,11 @@ pub async fn upload_signatures_route(
|
|||
|
||||
/// # `POST /_matrix/client/r0/keys/changes`
|
||||
///
|
||||
/// Gets a list of users who have updated their device identity keys since the previous sync token.
|
||||
/// Gets a list of users who have updated their device identity keys since the
|
||||
/// previous sync token.
|
||||
///
|
||||
/// - TODO: left users
|
||||
pub async fn get_key_changes_route(
|
||||
body: Ruma<get_key_changes::v3::Request>,
|
||||
) -> Result<get_key_changes::v3::Response> {
|
||||
pub async fn get_key_changes_route(body: Ruma<get_key_changes::v3::Request>) -> Result<get_key_changes::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let mut device_list_updates = HashSet::new();
|
||||
|
@ -221,35 +188,20 @@ pub async fn get_key_changes_route(
|
|||
.users
|
||||
.keys_changed(
|
||||
sender_user.as_str(),
|
||||
body.from
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(
|
||||
body.to
|
||||
.parse()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?,
|
||||
),
|
||||
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for room_id in services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok) {
|
||||
device_list_updates.extend(
|
||||
services()
|
||||
.users
|
||||
.keys_changed(
|
||||
room_id.as_ref(),
|
||||
body.from.parse().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`.")
|
||||
})?,
|
||||
Some(body.to.parse().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`.")
|
||||
})?),
|
||||
body.from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `from`."))?,
|
||||
Some(body.to.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid `to`."))?),
|
||||
)
|
||||
.filter_map(std::result::Result::ok),
|
||||
);
|
||||
|
@ -261,9 +213,7 @@ pub async fn get_key_changes_route(
|
|||
}
|
||||
|
||||
pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
||||
sender_user: Option<&UserId>,
|
||||
device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>,
|
||||
allowed_signatures: F,
|
||||
sender_user: Option<&UserId>, device_keys_input: &BTreeMap<OwnedUserId, Vec<OwnedDeviceId>>, allowed_signatures: F,
|
||||
include_display_names: bool,
|
||||
) -> Result<get_keys::v3::Response> {
|
||||
let mut master_keys = BTreeMap::new();
|
||||
|
@ -277,10 +227,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
let user_id: &UserId = user_id;
|
||||
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((user_id, device_ids));
|
||||
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, device_ids));
|
||||
continue;
|
||||
}
|
||||
|
||||
|
@ -292,9 +239,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, &device_id)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("all_device_keys contained nonexistent device.")
|
||||
})?;
|
||||
.ok_or_else(|| Error::bad_database("all_device_keys contained nonexistent device."))?;
|
||||
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
|
@ -307,13 +252,9 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
for device_id in device_ids {
|
||||
let mut container = BTreeMap::new();
|
||||
if let Some(mut keys) = services().users.get_device_keys(user_id, device_id)? {
|
||||
let metadata = services()
|
||||
.users
|
||||
.get_device_metadata(user_id, device_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to get keys for nonexistent device.",
|
||||
))?;
|
||||
let metadata = services().users.get_device_metadata(user_id, device_id)?.ok_or(
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Tried to get keys for nonexistent device."),
|
||||
)?;
|
||||
|
||||
add_unsigned_device_display_name(&mut keys, metadata, include_display_names)
|
||||
.map_err(|_| Error::bad_database("invalid device keys in database"))?;
|
||||
|
@ -323,17 +264,11 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
}
|
||||
}
|
||||
|
||||
if let Some(master_key) =
|
||||
services()
|
||||
.users
|
||||
.get_master_key(sender_user, user_id, &allowed_signatures)?
|
||||
{
|
||||
if let Some(master_key) = services().users.get_master_key(sender_user, user_id, &allowed_signatures)? {
|
||||
master_keys.insert(user_id.to_owned(), master_key);
|
||||
}
|
||||
if let Some(self_signing_key) =
|
||||
services()
|
||||
.users
|
||||
.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
||||
services().users.get_self_signing_key(sender_user, user_id, &allowed_signatures)?
|
||||
{
|
||||
self_signing_keys.insert(user_id.to_owned(), self_signing_key);
|
||||
}
|
||||
|
@ -346,29 +281,17 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
|
||||
let mut failures = BTreeMap::new();
|
||||
|
||||
let back_off = |id| match services()
|
||||
.globals
|
||||
.bad_query_ratelimiter
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(id)
|
||||
{
|
||||
let back_off = |id| match services().globals.bad_query_ratelimiter.write().unwrap().entry(id) {
|
||||
hash_map::Entry::Vacant(e) => {
|
||||
e.insert((Instant::now(), 1));
|
||||
}
|
||||
},
|
||||
hash_map::Entry::Occupied(mut e) => *e.get_mut() = (Instant::now(), e.get().1 + 1),
|
||||
};
|
||||
|
||||
let mut futures: FuturesUnordered<_> = get_over_federation
|
||||
.into_iter()
|
||||
.map(|(server, vec)| async move {
|
||||
if let Some((time, tries)) = services()
|
||||
.globals
|
||||
.bad_query_ratelimiter
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(server)
|
||||
{
|
||||
if let Some((time, tries)) = services().globals.bad_query_ratelimiter.read().unwrap().get(server) {
|
||||
// Exponential backoff
|
||||
let mut min_elapsed_duration = Duration::from_secs(5 * 60) * (*tries) * (*tries);
|
||||
if min_elapsed_duration > Duration::from_secs(60 * 60 * 24) {
|
||||
|
@ -377,10 +300,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
|
||||
if time.elapsed() < min_elapsed_duration {
|
||||
debug!("Backing off query from {:?}", server);
|
||||
return (
|
||||
server,
|
||||
Err(Error::BadServerResponse("bad query, still backing off")),
|
||||
);
|
||||
return (server, Err(Error::BadServerResponse("bad query, still backing off")));
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -412,35 +332,31 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
match response {
|
||||
Ok(Ok(response)) => {
|
||||
for (user, masterkey) in response.master_keys {
|
||||
let (master_key_id, mut master_key) =
|
||||
services().users.parse_master_key(&user, &masterkey)?;
|
||||
let (master_key_id, mut master_key) = services().users.parse_master_key(&user, &masterkey)?;
|
||||
|
||||
if let Some(our_master_key) = services().users.get_key(
|
||||
&master_key_id,
|
||||
sender_user,
|
||||
&user,
|
||||
&allowed_signatures,
|
||||
)? {
|
||||
let (_, our_master_key) =
|
||||
services().users.parse_master_key(&user, &our_master_key)?;
|
||||
if let Some(our_master_key) =
|
||||
services().users.get_key(&master_key_id, sender_user, &user, &allowed_signatures)?
|
||||
{
|
||||
let (_, our_master_key) = services().users.parse_master_key(&user, &our_master_key)?;
|
||||
master_key.signatures.extend(our_master_key.signatures);
|
||||
}
|
||||
let json = serde_json::to_value(master_key).expect("to_value always works");
|
||||
let raw = serde_json::from_value(json).expect("Raw::from_value always works");
|
||||
services().users.add_cross_signing_keys(
|
||||
&user, &raw, &None, &None,
|
||||
false, // Dont notify. A notification would trigger another key request resulting in an endless loop
|
||||
false, /* Dont notify. A notification would trigger another key request resulting in an
|
||||
* endless loop */
|
||||
)?;
|
||||
master_keys.insert(user, raw);
|
||||
}
|
||||
|
||||
self_signing_keys.extend(response.self_signing_keys);
|
||||
device_keys.extend(response.device_keys);
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
back_off(server.to_owned());
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -454,8 +370,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
|
|||
}
|
||||
|
||||
fn add_unsigned_device_display_name(
|
||||
keys: &mut Raw<ruma::encryption::DeviceKeys>,
|
||||
metadata: ruma::api::client::device::Device,
|
||||
keys: &mut Raw<ruma::encryption::DeviceKeys>, metadata: ruma::api::client::device::Device,
|
||||
include_display_names: bool,
|
||||
) -> serde_json::Result<()> {
|
||||
if let Some(display_name) = metadata.display_name {
|
||||
|
@ -488,19 +403,12 @@ pub(crate) async fn claim_keys_helper(
|
|||
|
||||
for (user_id, map) in one_time_keys_input {
|
||||
if user_id.server_name() != services().globals.server_name() {
|
||||
get_over_federation
|
||||
.entry(user_id.server_name())
|
||||
.or_insert_with(Vec::new)
|
||||
.push((user_id, map));
|
||||
get_over_federation.entry(user_id.server_name()).or_insert_with(Vec::new).push((user_id, map));
|
||||
}
|
||||
|
||||
let mut container = BTreeMap::new();
|
||||
for (device_id, key_algorithm) in map {
|
||||
if let Some(one_time_keys) =
|
||||
services()
|
||||
.users
|
||||
.take_one_time_key(user_id, device_id, key_algorithm)?
|
||||
{
|
||||
if let Some(one_time_keys) = services().users.take_one_time_key(user_id, device_id, key_algorithm)? {
|
||||
let mut c = BTreeMap::new();
|
||||
c.insert(one_time_keys.0, one_time_keys.1);
|
||||
container.insert(device_id.clone(), c);
|
||||
|
@ -537,10 +445,10 @@ pub(crate) async fn claim_keys_helper(
|
|||
match response {
|
||||
Ok(keys) => {
|
||||
one_time_keys.extend(keys.one_time_keys);
|
||||
}
|
||||
},
|
||||
Err(_e) => {
|
||||
failures.insert(server.to_string(), json!({}));
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,21 +1,21 @@
use std::{io::Cursor, net::IpAddr, sync::Arc, time::Duration};

use image::io::Reader as ImgReader;
use reqwest::Url;
use ruma::api::client::{
error::ErrorKind,
media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail, get_media_config,
get_media_preview,
},
};
use tracing::{debug, error, info, warn};
use webpage::HTML;

use crate::{
service::media::{FileMeta, UrlPreviewData},
services, utils, Error, Result, Ruma,
};
use image::io::Reader as ImgReader;

use reqwest::Url;
use ruma::api::client::{
error::ErrorKind,
media::{
create_content, get_content, get_content_as_filename, get_content_thumbnail,
get_media_config, get_media_preview,
},
};
use tracing::{debug, error, info, warn};
use webpage::HTML;

/// generated MXC ID (`media-id`) length
const MXC_LENGTH: usize = 32;
@ -39,22 +39,13 @@ pub async fn get_media_preview_route(
|
|||
) -> Result<get_media_preview::v3::Response> {
|
||||
let url = &body.url;
|
||||
if !url_preview_allowed(url) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"URL is not allowed to be previewed",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "URL is not allowed to be previewed"));
|
||||
}
|
||||
|
||||
if let Ok(preview) = get_url_preview(url).await {
|
||||
let res = serde_json::value::to_raw_value(&preview).map_err(|e| {
|
||||
error!(
|
||||
"Failed to convert UrlPreviewData into a serde json value: {}",
|
||||
e
|
||||
);
|
||||
Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unknown error occurred parsing URL preview",
|
||||
)
|
||||
error!("Failed to convert UrlPreviewData into a serde json value: {}", e);
|
||||
Error::BadRequest(ErrorKind::Unknown, "Unknown error occurred parsing URL preview")
|
||||
})?;
|
||||
|
||||
return Ok(get_media_preview::v3::Response::from_raw_value(res));
|
||||
|
@ -74,9 +65,7 @@ pub async fn get_media_preview_route(
|
|||
///
|
||||
/// - Some metadata will be saved in the database
|
||||
/// - Media will be saved in the media/ directory
|
||||
pub async fn create_content_route(
|
||||
body: Ruma<create_content::v3::Request>,
|
||||
) -> Result<create_content::v3::Response> {
|
||||
pub async fn create_content_route(body: Ruma<create_content::v3::Request>) -> Result<create_content::v3::Response> {
|
||||
let mxc = format!(
|
||||
"mxc://{}/{}",
|
||||
services().globals.server_name(),
|
||||
|
@ -87,10 +76,7 @@ pub async fn create_content_route(
|
|||
.media
|
||||
.create(
|
||||
mxc.clone(),
|
||||
body.filename
|
||||
.as_ref()
|
||||
.map(|filename| "inline; filename=".to_owned() + filename)
|
||||
.as_deref(),
|
||||
body.filename.as_ref().map(|filename| "inline; filename=".to_owned() + filename).as_deref(),
|
||||
body.content_type.as_deref(),
|
||||
&body.file,
|
||||
)
|
||||
|
@ -106,20 +92,15 @@ pub async fn create_content_route(
|
|||
|
||||
/// helper method to fetch remote media from other servers over federation
|
||||
pub async fn get_remote_content(
|
||||
mxc: &str,
|
||||
server_name: &ruma::ServerName,
|
||||
media_id: String,
|
||||
allow_redirect: bool,
|
||||
timeout_ms: Duration,
|
||||
mxc: &str, server_name: &ruma::ServerName, media_id: String, allow_redirect: bool, timeout_ms: Duration,
|
||||
) -> Result<get_content::v3::Response, Error> {
|
||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
||||
// the client has no way of telling anyways so this is a security bonus.
|
||||
if services()
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&server_name.to_owned())
|
||||
{
|
||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
||||
// we'll lie to the client and say the blocked server's media was not found and
|
||||
// log. the client has no way of telling anyways so this is a security bonus.
|
||||
if services().globals.prevent_media_downloads_from().contains(&server_name.to_owned()) {
|
||||
info!(
|
||||
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||
mxc
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
|
@ -156,10 +137,9 @@ pub async fn get_remote_content(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
pub async fn get_content_route(
|
||||
body: Ruma<get_content::v3::Request>,
|
||||
) -> Result<get_content::v3::Response> {
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_route(body: Ruma<get_content::v3::Request>) -> Result<get_content::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
|
@ -195,14 +175,17 @@ pub async fn get_content_route(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_as_filename_route(
|
||||
body: Ruma<get_content_as_filename::v3::Request>,
|
||||
) -> Result<get_content_as_filename::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_type, file, ..
|
||||
content_type,
|
||||
file,
|
||||
..
|
||||
}) = services().media.get(mxc.clone()).await?
|
||||
{
|
||||
Ok(get_content_as_filename::v3::Response {
|
||||
|
@ -238,24 +221,23 @@ pub async fn get_content_as_filename_route(
|
|||
///
|
||||
/// - Only allows federation if `allow_remote` is true
|
||||
/// - Only redirects if `allow_redirect` is true
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20 seconds
|
||||
/// - Uses client-provided `timeout_ms` if available, else defaults to 20
|
||||
/// seconds
|
||||
pub async fn get_content_thumbnail_route(
|
||||
body: Ruma<get_content_thumbnail::v3::Request>,
|
||||
) -> Result<get_content_thumbnail::v3::Response> {
|
||||
let mxc = format!("mxc://{}/{}", body.server_name, body.media_id);
|
||||
|
||||
if let Some(FileMeta {
|
||||
content_type, file, ..
|
||||
content_type,
|
||||
file,
|
||||
..
|
||||
}) = services()
|
||||
.media
|
||||
.get_thumbnail(
|
||||
mxc.clone(),
|
||||
body.width
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
body.height
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||
body.width.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Width is invalid."))?,
|
||||
body.height.try_into().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Height is invalid."))?,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
|
@ -265,14 +247,13 @@ pub async fn get_content_thumbnail_route(
|
|||
cross_origin_resource_policy: Some("cross-origin".to_owned()),
|
||||
})
|
||||
} else if &*body.server_name != services().globals.server_name() && body.allow_remote {
|
||||
// we'll lie to the client and say the blocked server's media was not found and log.
|
||||
// the client has no way of telling anyways so this is a security bonus.
|
||||
if services()
|
||||
.globals
|
||||
.prevent_media_downloads_from()
|
||||
.contains(&body.server_name.clone())
|
||||
{
|
||||
info!("Received request for remote media `{}` but server is in our media server blocklist. Returning 404.", mxc);
|
||||
// we'll lie to the client and say the blocked server's media was not found and
|
||||
// log. the client has no way of telling anyways so this is a security bonus.
|
||||
if services().globals.prevent_media_downloads_from().contains(&body.server_name.clone()) {
|
||||
info!(
|
||||
"Received request for remote media `{}` but server is in our media server blocklist. Returning 404.",
|
||||
mxc
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Media not found."));
|
||||
}
|
||||
|
||||
|
@ -319,10 +300,7 @@ async fn download_image(client: &reqwest::Client, url: &str) -> Result<UrlPrevie
|
|||
utils::random_string(MXC_LENGTH)
|
||||
);
|
||||
|
||||
services()
|
||||
.media
|
||||
.create(mxc.clone(), None, None, &image)
|
||||
.await?;
|
||||
services().media.create(mxc.clone(), None, None, &image).await?;
|
||||
|
||||
let (width, height) = match ImgReader::new(Cursor::new(&image)).with_guessed_format() {
|
||||
Err(_) => (None, None),
|
||||
|
@ -348,19 +326,19 @@ async fn download_html(client: &reqwest::Client, url: &str) -> Result<UrlPreview
|
|||
while let Some(chunk) = response.chunk().await? {
|
||||
bytes.extend_from_slice(&chunk);
|
||||
if bytes.len() > services().globals.url_preview_max_spider_size() {
|
||||
debug!("Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the response body and assuming our necessary data is in this range.", url, services().globals.url_preview_max_spider_size());
|
||||
debug!(
|
||||
"Response body from URL {} exceeds url_preview_max_spider_size ({}), not processing the rest of the \
|
||||
response body and assuming our necessary data is in this range.",
|
||||
url,
|
||||
services().globals.url_preview_max_spider_size()
|
||||
);
|
||||
break;
|
||||
}
|
||||
}
|
||||
let body = String::from_utf8_lossy(&bytes);
|
||||
let html = match HTML::from_string(body.to_string(), Some(url.to_owned())) {
|
||||
Ok(html) => html,
|
||||
Err(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to parse HTML",
|
||||
))
|
||||
}
|
||||
Err(_) => return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to parse HTML")),
|
||||
};
|
||||
|
||||
let mut data = match html.opengraph.images.first() {
|
||||
|
@ -399,7 +377,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
|||
|| (ip4.octets()[0] == 198 && (ip4.octets()[1] & 0xfe) == 18) // is_benchmarking()
|
||||
|| (ip4.octets()[0] & 240 == 240 && !ip4.is_broadcast()) // is_reserved()
|
||||
|| ip4.is_broadcast())
|
||||
}
|
||||
},
|
||||
IpAddr::V6(ip6) => {
|
||||
!(ip6.is_unspecified()
|
||||
|| ip6.is_loopback()
|
||||
|
@ -426,7 +404,7 @@ fn url_request_allowed(addr: &IpAddr) -> bool {
|
|||
|| ((ip6.segments()[0] == 0x2001) && (ip6.segments()[1] == 0xdb8)) // is_documentation()
|
||||
|| ((ip6.segments()[0] & 0xfe00) == 0xfc00) // is_unique_local()
|
||||
|| ((ip6.segments()[0] & 0xffc0) == 0xfe80)) // is_unicast_link_local
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -434,38 +412,21 @@ async fn request_url_preview(url: &str) -> Result<UrlPreviewData> {
|
|||
let client = services().globals.url_preview_client();
|
||||
let response = client.head(url).send().await?;
|
||||
|
||||
if !response
|
||||
.remote_addr()
|
||||
.map_or(false, |a| url_request_allowed(&a.ip()))
|
||||
{
|
||||
if !response.remote_addr().map_or(false, |a| url_request_allowed(&a.ip())) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Requesting from this address is forbidden",
|
||||
));
|
||||
}
|
||||
|
||||
let content_type = match response
|
||||
.headers()
|
||||
.get(reqwest::header::CONTENT_TYPE)
|
||||
.and_then(|x| x.to_str().ok())
|
||||
{
|
||||
let content_type = match response.headers().get(reqwest::header::CONTENT_TYPE).and_then(|x| x.to_str().ok()) {
|
||||
Some(ct) => ct,
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unknown Content-Type",
|
||||
))
|
||||
}
|
||||
None => return Err(Error::BadRequest(ErrorKind::Unknown, "Unknown Content-Type")),
|
||||
};
|
||||
let data = match content_type {
|
||||
html if html.starts_with("text/html") => download_html(&client, url).await?,
|
||||
img if img.starts_with("image/") => download_image(&client, url).await?,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unsupported Content-Type",
|
||||
))
|
||||
}
|
||||
_ => return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported Content-Type")),
|
||||
};
|
||||
|
||||
services().media.set_url_preview(url, &data).await?;
|
||||
|
@ -479,15 +440,8 @@ async fn get_url_preview(url: &str) -> Result<UrlPreviewData> {
|
|||
}
|
||||
|
||||
// ensure that only one request is made per URL
|
||||
let mutex_request = Arc::clone(
|
||||
services()
|
||||
.media
|
||||
.url_preview_mutex
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(url.to_owned())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_request =
|
||||
Arc::clone(services().media.url_preview_mutex.write().unwrap().entry(url.to_owned()).or_default());
|
||||
let _request_lock = mutex_request.lock().await;
|
||||
|
||||
match services().media.get_url_preview(url).await {
|
||||
|
@ -502,25 +456,19 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
|||
Err(e) => {
|
||||
warn!("Failed to parse URL from a str: {}", e);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if ["http", "https"]
|
||||
.iter()
|
||||
.all(|&scheme| scheme != url.scheme().to_lowercase())
|
||||
{
|
||||
if ["http", "https"].iter().all(|&scheme| scheme != url.scheme().to_lowercase()) {
|
||||
debug!("Ignoring non-HTTP/HTTPS URL to preview: {}", url);
|
||||
return false;
|
||||
}
|
||||
|
||||
let host = match url.host_str() {
|
||||
None => {
|
||||
debug!(
|
||||
"Ignoring URL preview for a URL that does not have a host (?): {}",
|
||||
url
|
||||
);
|
||||
debug!("Ignoring URL preview for a URL that does not have a host (?): {}", url);
|
||||
return false;
|
||||
}
|
||||
},
|
||||
Some(h) => h.to_owned(),
|
||||
};
|
||||
|
||||
|
@ -532,41 +480,23 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
|||
|| allowlist_domain_explicit.contains(&"*".to_owned())
|
||||
|| allowlist_url_contains.contains(&"*".to_owned())
|
||||
{
|
||||
debug!(
|
||||
"Config key contains * which is allowing all URL previews. Allowing URL {}",
|
||||
url
|
||||
);
|
||||
debug!("Config key contains * which is allowing all URL previews. Allowing URL {}", url);
|
||||
return true;
|
||||
}
|
||||
|
||||
if !host.is_empty() {
|
||||
if allowlist_domain_explicit.contains(&host) {
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)",
|
||||
&host
|
||||
);
|
||||
debug!("Host {} is allowed by url_preview_domain_explicit_allowlist (check 1/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowlist_domain_contains
|
||||
.iter()
|
||||
.any(|domain_s| domain_s.contains(&host.clone()))
|
||||
{
|
||||
debug!(
|
||||
"Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||
&host
|
||||
);
|
||||
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&host.clone())) {
|
||||
debug!("Host {} is allowed by url_preview_domain_contains_allowlist (check 2/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
if allowlist_url_contains
|
||||
.iter()
|
||||
.any(|url_s| url.to_string().contains(&url_s.to_string()))
|
||||
{
|
||||
debug!(
|
||||
"URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)",
|
||||
&host
|
||||
);
|
||||
if allowlist_url_contains.iter().any(|url_s| url.to_string().contains(&url_s.to_string())) {
|
||||
debug!("URL {} is allowed by url_preview_url_contains_allowlist (check 3/3)", &host);
|
||||
return true;
|
||||
}
|
||||
|
||||
|
@ -584,17 +514,14 @@ fn url_preview_allowed(url_str: &str) -> bool {
|
|||
return true;
|
||||
}
|
||||
|
||||
if allowlist_domain_contains
|
||||
.iter()
|
||||
.any(|domain_s| domain_s.contains(&root_domain.to_owned()))
|
||||
{
|
||||
if allowlist_domain_contains.iter().any(|domain_s| domain_s.contains(&root_domain.to_owned())) {
|
||||
debug!(
|
||||
"Root domain {} is allowed by url_preview_domain_contains_allowlist (check 2/3)",
|
||||
&root_domain
|
||||
);
|
||||
return true;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
File diff suppressed because it is too large
@ -1,7 +1,8 @@
|
|||
use crate::{
|
||||
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
||||
services, utils, Error, Result, Ruma,
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -10,47 +11,40 @@ use ruma::{
|
|||
events::{StateEventType, TimelineEventType},
|
||||
};
|
||||
use serde_json::from_str;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashSet},
|
||||
sync::Arc,
|
||||
|
||||
use crate::{
|
||||
service::{pdu::PduBuilder, rooms::timeline::PduCount},
|
||||
services, utils, Error, Result, Ruma,
|
||||
};
|
||||
|
||||
/// # `PUT /_matrix/client/v3/rooms/{roomId}/send/{eventType}/{txnId}`
|
||||
///
|
||||
/// Send a message event into the room.
|
||||
///
|
||||
/// - Is a NOOP if the txn id was already used before and returns the same event id again
|
||||
/// - Is a NOOP if the txn id was already used before and returns the same event
|
||||
/// id again
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
pub async fn send_message_event_route(
|
||||
body: Ruma<send_message_event::v3::Request>,
|
||||
) -> Result<send_message_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(body.room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
// Forbid m.room.encrypted if encryption is disabled
|
||||
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into()
|
||||
&& !services().globals.allow_encryption()
|
||||
if TimelineEventType::RoomEncrypted == body.event_type.to_string().into() && !services().globals.allow_encryption()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Encryption has been disabled",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||
}
|
||||
|
||||
// certain event types require certain fields to be valid in request bodies.
|
||||
// this helps prevent attempting to handle events that we can't deserialise later so don't waste resources on it.
|
||||
// this helps prevent attempting to handle events that we can't deserialise
|
||||
// later so don't waste resources on it.
|
||||
//
|
||||
// see https://spec.matrix.org/v1.9/client-server-api/#events-2 for what's required per event type.
|
||||
match body.event_type.to_string().into() {
|
||||
|
@ -71,7 +65,7 @@ pub async fn send_message_event_route(
|
|||
"'msgtype' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomName => {
|
||||
let name_field = body.body.body.get_field::<String>("name");
|
||||
|
||||
|
@ -81,7 +75,7 @@ pub async fn send_message_event_route(
|
|||
"'name' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
},
|
||||
TimelineEventType::RoomTopic => {
|
||||
let topic_field = body.body.body.get_field::<String>("topic");
|
||||
|
||||
|
@ -91,16 +85,12 @@ pub async fn send_message_event_route(
|
|||
"'topic' field in JSON request is invalid",
|
||||
));
|
||||
}
|
||||
}
|
||||
_ => {} // event may be custom/experimental or can be empty don't do anything with it
|
||||
},
|
||||
_ => {}, // event may be custom/experimental or can be empty don't do anything with it
|
||||
};
|
||||
|
||||
// Check if this is a new transaction id
|
||||
if let Some(response) =
|
||||
services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
{
|
||||
if let Some(response) = services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)? {
|
||||
// The client might have sent a txnid of the /sendToDevice endpoint
|
||||
// This txnid has no response associated with it
|
||||
if response.is_empty() {
|
||||
|
@ -114,7 +104,9 @@ pub async fn send_message_event_route(
|
|||
.map_err(|_| Error::bad_database("Invalid txnid bytes in database."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid event id in txnid data."))?;
|
||||
return Ok(send_message_event::v3::Response { event_id });
|
||||
return Ok(send_message_event::v3::Response {
|
||||
event_id,
|
||||
});
|
||||
}
|
||||
|
||||
let mut unsigned = BTreeMap::new();
|
||||
|
@ -138,25 +130,19 @@ pub async fn send_message_event_route(
|
|||
)
|
||||
.await?;
|
||||
|
||||
services().transaction_ids.add_txnid(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.txn_id,
|
||||
event_id.as_bytes(),
|
||||
)?;
|
||||
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, event_id.as_bytes())?;
|
||||
|
||||
drop(state_lock);
|
||||
|
||||
Ok(send_message_event::v3::Response::new(
|
||||
(*event_id).to_owned(),
|
||||
))
|
||||
Ok(send_message_event::v3::Response::new((*event_id).to_owned()))
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/messages`
|
||||
///
|
||||
/// Allows paginating through room history.
|
||||
///
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events where the user was
|
||||
/// - Only works if the user is joined (TODO: always allow, but only show events
|
||||
/// where the user was
|
||||
/// joined, depending on history_visibility)
|
||||
pub async fn get_message_events_route(
|
||||
body: Ruma<get_message_events::v3::Request>,
|
||||
|
@ -172,17 +158,9 @@ pub async fn get_message_events_route(
|
|||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(
|
||||
sender_user,
|
||||
sender_device,
|
||||
&body.room_id,
|
||||
from,
|
||||
)?;
|
||||
services().rooms.lazy_loading.lazy_load_confirm_delivery(sender_user, sender_device, &body.room_id, from)?;
|
||||
|
||||
let limit = u64::from(body.limit).min(100) as usize;
|
||||
|
||||
|
@@ -228,21 +206,14 @@ pub async fn get_message_events_route(

next_token = events_after.last().map(|(count, _)| count).copied();

let events_after: Vec<_> = events_after
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_after: Vec<_> = events_after.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();

resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_after;
}
},
ruma::api::Direction::Backward => {
services()
.rooms
.timeline
.backfill_if_required(&body.room_id, from)
.await?;
services().rooms.timeline.backfill_if_required(&body.room_id, from).await?;
let events_before: Vec<_> = services()
.rooms
.timeline
@@ -277,15 +248,12 @@ pub async fn get_message_events_route(

next_token = events_before.last().map(|(count, _)| count).copied();

let events_before: Vec<_> = events_before
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect();
let events_before: Vec<_> = events_before.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect();

resp.start = from.stringify();
resp.end = next_token.map(|count| count.stringify());
resp.chunk = events_before;
}
},
}

resp.state = Vec::new();
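// Illustrative sketch of how the /messages response above is assembled: events after
// (or before) the `from` token are collected up to `limit`, `start` echoes the
// requested token, and `end` is the token of the last returned event, if any.
// Numeric counts stand in for PduCount tokens and strings for room events here.
fn paginate(events: &[(u64, String)], from: u64, limit: usize) -> (String, Option<String>, Vec<String>) {
	let page: Vec<&(u64, String)> = events.iter().filter(|(count, _)| *count > from).take(limit).collect();
	let next_token = page.last().map(|(count, _)| *count);
	let chunk: Vec<String> = page.iter().map(|(_, event)| event.clone()).collect();
	(from.to_string(), next_token.map(|count| count.to_string()), chunk)
}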
@@ -1,21 +1,18 @@
use crate::{services, Error, Result, Ruma};
use std::time::Duration;

use ruma::api::client::{
error::ErrorKind,
presence::{get_presence, set_presence},
};
use std::time::Duration;

use crate::{services, Error, Result, Ruma};

/// # `PUT /_matrix/client/r0/presence/{userId}/status`
///
/// Sets the presence state of the sender user.
pub async fn set_presence_route(
body: Ruma<set_presence::v3::Request>,
) -> Result<set_presence::v3::Response> {
pub async fn set_presence_route(body: Ruma<set_presence::v3::Request>) -> Result<set_presence::v3::Response> {
if !services().globals.allow_local_presence() {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Presence is disabled on this server",
));
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
}

let sender_user = body.sender_user.as_ref().expect("user is authenticated");
@@ -40,33 +37,19 @@ pub async fn set_presence_route(
/// Gets the presence state of the given user.
///
/// - Only works if you share a room with the user
pub async fn get_presence_route(
body: Ruma<get_presence::v3::Request>,
) -> Result<get_presence::v3::Response> {
pub async fn get_presence_route(body: Ruma<get_presence::v3::Request>) -> Result<get_presence::v3::Response> {
if !services().globals.allow_local_presence() {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"Presence is disabled on this server",
));
return Err(Error::BadRequest(ErrorKind::Forbidden, "Presence is disabled on this server"));
}

let sender_user = body.sender_user.as_ref().expect("user is authenticated");

let mut presence_event = None;

for room_id in services()
.rooms
.user
.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])?
{
for room_id in services().rooms.user.get_shared_rooms(vec![sender_user.clone(), body.user_id.clone()])? {
let room_id = room_id?;

if let Some(presence) = services()
.rooms
.edus
.presence
.get_presence(&room_id, sender_user)?
{
if let Some(presence) = services().rooms.edus.presence.get_presence(&room_id, sender_user)? {
presence_event = Some(presence);
break;
}
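// Illustrative sketch of the response mapping in the hunk below: the stored presence
// content is copied field-for-field into the response, converting last_active_ago
// from milliseconds into a Duration. Struct names here are stand-ins, not ruma types.
use std::time::Duration;

struct StoredPresence {
	status_msg: Option<String>,
	currently_active: Option<bool>,
	last_active_ago: Option<u64>, // milliseconds since last activity
	presence: String,
}

struct PresenceReply {
	status_msg: Option<String>,
	currently_active: Option<bool>,
	last_active_ago: Option<Duration>,
	presence: String,
}

fn to_reply(p: StoredPresence) -> PresenceReply {
	PresenceReply {
		status_msg: p.status_msg,
		currently_active: p.currently_active,
		last_active_ago: p.last_active_ago.map(Duration::from_millis),
		presence: p.presence,
	}
}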
@@ -77,10 +60,7 @@ pub async fn get_presence_route(
// TODO: Should ruma just use the presenceeventcontent type here?
status_msg: presence.content.status_msg,
currently_active: presence.content.currently_active,
last_active_ago: presence
.content
.last_active_ago
.map(|millis| Duration::from_millis(millis.into())),
last_active_ago: presence.content.last_active_ago.map(|millis| Duration::from_millis(millis.into())),
presence: presence.content.presence,
})
} else {
@ -4,9 +4,7 @@ use ruma::{
|
|||
api::{
|
||||
client::{
|
||||
error::ErrorKind,
|
||||
profile::{
|
||||
get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name,
|
||||
},
|
||||
profile::{get_avatar_url, get_display_name, get_profile, set_avatar_url, set_display_name},
|
||||
},
|
||||
federation,
|
||||
},
|
||||
|
@ -27,10 +25,7 @@ pub async fn set_displayname_route(
|
|||
) -> Result<set_display_name::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(sender_user, body.displayname.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(sender_user, body.displayname.clone()).await?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_rooms_joined: Vec<_> = services()
|
||||
|
@ -48,16 +43,9 @@ pub async fn set_displayname_route(
|
|||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
sender_user.as_str(),
|
||||
)?
|
||||
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database(
|
||||
"Tried to send displayname update for user not in the \
|
||||
room.",
|
||||
)
|
||||
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
|
@ -76,31 +64,16 @@ pub async fn set_displayname_route(
|
|||
.collect();
|
||||
|
||||
for (pdu_builder, room_id) in all_rooms_joined {
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
||||
.await;
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||
}
|
||||
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.ping_presence(sender_user, PresenceState::Online)?;
|
||||
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
|
||||
Ok(set_display_name::v3::Response {})
|
||||
|
@ -132,18 +105,9 @@ pub async fn get_displayname_route(
|
|||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_display_name::v3::Response {
|
||||
displayname: response.displayname,
|
||||
|
@ -152,11 +116,9 @@ pub async fn get_displayname_route(
|
|||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_display_name::v3::Response {
|
||||
|
@ -169,20 +131,12 @@ pub async fn get_displayname_route(
|
|||
/// Updates the avatar_url and blurhash.
|
||||
///
|
||||
/// - Also makes sure other users receive the update using presence EDUs
|
||||
pub async fn set_avatar_url_route(
|
||||
body: Ruma<set_avatar_url::v3::Request>,
|
||||
) -> Result<set_avatar_url::v3::Response> {
|
||||
pub async fn set_avatar_url_route(body: Ruma<set_avatar_url::v3::Request>) -> Result<set_avatar_url::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(sender_user, body.avatar_url.clone())
|
||||
.await?;
|
||||
services().users.set_avatar_url(sender_user, body.avatar_url.clone()).await?;
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(sender_user, body.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_blurhash(sender_user, body.blurhash.clone()).await?;
|
||||
|
||||
// Send a new membership event and presence update into all joined rooms
|
||||
let all_joined_rooms: Vec<_> = services()
|
||||
|
@ -200,16 +154,9 @@ pub async fn set_avatar_url_route(
|
|||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(
|
||||
&room_id,
|
||||
&StateEventType::RoomMember,
|
||||
sender_user.as_str(),
|
||||
)?
|
||||
.room_state_get(&room_id, &StateEventType::RoomMember, sender_user.as_str())?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database(
|
||||
"Tried to send displayname update for user not in the \
|
||||
room.",
|
||||
)
|
||||
Error::bad_database("Tried to send displayname update for user not in the room.")
|
||||
})?
|
||||
.content
|
||||
.get(),
|
||||
|
@ -228,31 +175,16 @@ pub async fn set_avatar_url_route(
|
|||
.collect();
|
||||
|
||||
for (pdu_builder, room_id) in all_joined_rooms {
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let _ = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
||||
.await;
|
||||
let _ = services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await;
|
||||
}
|
||||
|
||||
if services().globals.allow_local_presence() {
|
||||
// Presence update
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.presence
|
||||
.ping_presence(sender_user, PresenceState::Online)?;
|
||||
services().rooms.edus.presence.ping_presence(sender_user, PresenceState::Online)?;
|
||||
}
|
||||
|
||||
Ok(set_avatar_url::v3::Response {})
|
||||
|
@ -264,9 +196,7 @@ pub async fn set_avatar_url_route(
|
|||
///
|
||||
/// - If user is on another server and we do not have a local copy already
|
||||
/// fetch avatar_url and blurhash over federation
|
||||
pub async fn get_avatar_url_route(
|
||||
body: Ruma<get_avatar_url::v3::Request>,
|
||||
) -> Result<get_avatar_url::v3::Response> {
|
||||
pub async fn get_avatar_url_route(body: Ruma<get_avatar_url::v3::Request>) -> Result<get_avatar_url::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
|
@ -284,18 +214,9 @@ pub async fn get_avatar_url_route(
|
|||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_avatar_url::v3::Response {
|
||||
avatar_url: response.avatar_url,
|
||||
|
@ -305,11 +226,9 @@ pub async fn get_avatar_url_route(
|
|||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_avatar_url::v3::Response {
|
||||
|
@ -324,9 +243,7 @@ pub async fn get_avatar_url_route(
|
|||
///
|
||||
/// - If user is on another server and we do not have a local copy already,
|
||||
/// fetch profile over federation.
|
||||
pub async fn get_profile_route(
|
||||
body: Ruma<get_profile::v3::Request>,
|
||||
) -> Result<get_profile::v3::Response> {
|
||||
pub async fn get_profile_route(body: Ruma<get_profile::v3::Request>) -> Result<get_profile::v3::Response> {
|
||||
if body.user_id.server_name() != services().globals.server_name() {
|
||||
// Create and update our local copy of the user
|
||||
if let Ok(response) = services()
|
||||
|
@ -344,18 +261,9 @@ pub async fn get_profile_route(
|
|||
services().users.create(&body.user_id, None)?;
|
||||
}
|
||||
|
||||
services()
|
||||
.users
|
||||
.set_displayname(&body.user_id, response.displayname.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_avatar_url(&body.user_id, response.avatar_url.clone())
|
||||
.await?;
|
||||
services()
|
||||
.users
|
||||
.set_blurhash(&body.user_id, response.blurhash.clone())
|
||||
.await?;
|
||||
services().users.set_displayname(&body.user_id, response.displayname.clone()).await?;
|
||||
services().users.set_avatar_url(&body.user_id, response.avatar_url.clone()).await?;
|
||||
services().users.set_blurhash(&body.user_id, response.blurhash.clone()).await?;
|
||||
|
||||
return Ok(get_profile::v3::Response {
|
||||
displayname: response.displayname,
|
||||
|
@ -366,11 +274,9 @@ pub async fn get_profile_route(
|
|||
}
|
||||
|
||||
if !services().users.exists(&body.user_id)? {
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over federation
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Profile was not found.",
|
||||
));
|
||||
// Return 404 if this user doesn't exist and we couldn't fetch it over
|
||||
// federation
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Profile was not found."));
|
||||
}
|
||||
|
||||
Ok(get_profile::v3::Response {
|
||||
|
|
|
@ -1,17 +1,17 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
push::{
|
||||
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled,
|
||||
get_pushrules_all, set_pusher, set_pushrule, set_pushrule_actions,
|
||||
set_pushrule_enabled, RuleScope,
|
||||
delete_pushrule, get_pushers, get_pushrule, get_pushrule_actions, get_pushrule_enabled, get_pushrules_all,
|
||||
set_pusher, set_pushrule, set_pushrule_actions, set_pushrule_enabled, RuleScope,
|
||||
},
|
||||
},
|
||||
events::{push_rules::PushRulesEvent, GlobalAccountDataEventType},
|
||||
push::{InsertPushRuleError, RemovePushRuleError},
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `GET /_matrix/client/r0/pushrules`
|
||||
///
|
||||
/// Retrieves the push rules event for this user.
|
||||
|
@ -22,15 +22,8 @@ pub async fn get_pushrules_all_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
|
@ -44,48 +37,33 @@ pub async fn get_pushrules_all_route(
|
|||
/// # `GET /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Retrieves a single specified push rule for this user.
|
||||
pub async fn get_pushrule_route(
|
||||
body: Ruma<get_pushrule::v3::Request>,
|
||||
) -> Result<get_pushrule::v3::Response> {
|
||||
pub async fn get_pushrule_route(body: Ruma<get_pushrule::v3::Request>) -> Result<get_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
.content;
|
||||
|
||||
let rule = account_data
|
||||
.global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(Into::into);
|
||||
let rule = account_data.global.get(body.kind.clone(), &body.rule_id).map(Into::into);
|
||||
|
||||
if let Some(rule) = rule {
|
||||
Ok(get_pushrule::v3::Response { rule })
|
||||
Ok(get_pushrule::v3::Response {
|
||||
rule,
|
||||
})
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))
|
||||
Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))
|
||||
}
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Creates a single specified push rule for this user.
|
||||
pub async fn set_pushrule_route(
|
||||
body: Ruma<set_pushrule::v3::Request>,
|
||||
) -> Result<set_pushrule::v3::Response> {
|
||||
pub async fn set_pushrule_route(body: Ruma<set_pushrule::v3::Request>) -> Result<set_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
|
@ -98,41 +76,30 @@ pub async fn set_pushrule_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if let Err(error) = account_data.content.global.insert(
|
||||
body.rule.clone(),
|
||||
body.after.as_deref(),
|
||||
body.before.as_deref(),
|
||||
) {
|
||||
if let Err(error) =
|
||||
account_data.content.global.insert(body.rule.clone(), body.after.as_deref(), body.before.as_deref())
|
||||
{
|
||||
let err = match error {
|
||||
InsertPushRuleError::ServerDefaultRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule IDs starting with a dot are reserved for server-default rules.",
|
||||
),
|
||||
InsertPushRuleError::InvalidRuleId => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Rule ID containing invalid characters.",
|
||||
),
|
||||
InsertPushRuleError::InvalidRuleId => {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Rule ID containing invalid characters.")
|
||||
},
|
||||
InsertPushRuleError::RelativeToServerDefaultRule => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Can't place a push rule relatively to a server-default rule.",
|
||||
),
|
||||
InsertPushRuleError::UnknownRuleId => Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"The before or after rule could not be found.",
|
||||
),
|
||||
InsertPushRuleError::UnknownRuleId => {
|
||||
Error::BadRequest(ErrorKind::NotFound, "The before or after rule could not be found.")
|
||||
},
|
||||
InsertPushRuleError::BeforeHigherThanAfter => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"The before rule has a higher priority than the after rule.",
|
||||
|
@ -170,15 +137,8 @@ pub async fn get_pushrule_actions_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?
|
||||
|
@ -188,12 +148,11 @@ pub async fn get_pushrule_actions_route(
|
|||
let actions = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(|rule| rule.actions().to_owned())
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||
|
||||
Ok(get_pushrule_actions::v3::Response { actions })
|
||||
Ok(get_pushrule_actions::v3::Response {
|
||||
actions,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/actions`
|
||||
|
@ -213,29 +172,14 @@ pub async fn set_pushrule_actions_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if account_data
|
||||
.content
|
||||
.global
|
||||
.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone())
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
));
|
||||
if account_data.content.global.set_actions(body.kind.clone(), &body.rule_id, body.actions.clone()).is_err() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
|
@ -265,15 +209,8 @@ pub async fn get_pushrule_enabled_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
@ -282,12 +219,11 @@ pub async fn get_pushrule_enabled_route(
|
|||
let enabled = global
|
||||
.get(body.kind.clone(), &body.rule_id)
|
||||
.map(ruma::push::AnyPushRuleRef::enabled)
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."))?;
|
||||
|
||||
Ok(get_pushrule_enabled::v3::Response { enabled })
|
||||
Ok(get_pushrule_enabled::v3::Response {
|
||||
enabled,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}/enabled`
|
||||
|
@ -307,29 +243,14 @@ pub async fn set_pushrule_enabled_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if account_data
|
||||
.content
|
||||
.global
|
||||
.set_enabled(body.kind.clone(), &body.rule_id, body.enabled)
|
||||
.is_err()
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Push rule not found.",
|
||||
));
|
||||
if account_data.content.global.set_enabled(body.kind.clone(), &body.rule_id, body.enabled).is_err() {
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Push rule not found."));
|
||||
}
|
||||
|
||||
services().account_data.update(
|
||||
|
@ -345,9 +266,7 @@ pub async fn set_pushrule_enabled_route(
|
|||
/// # `DELETE /_matrix/client/r0/pushrules/{scope}/{kind}/{ruleId}`
|
||||
///
|
||||
/// Deletes a single specified push rule for this user.
|
||||
pub async fn delete_pushrule_route(
|
||||
body: Ruma<delete_pushrule::v3::Request>,
|
||||
) -> Result<delete_pushrule::v3::Response> {
|
||||
pub async fn delete_pushrule_route(body: Ruma<delete_pushrule::v3::Request>) -> Result<delete_pushrule::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if body.scope != RuleScope::Global {
|
||||
|
@ -359,32 +278,18 @@ pub async fn delete_pushrule_route(
|
|||
|
||||
let event = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
sender_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PushRules event not found.",
|
||||
))?;
|
||||
.get(None, sender_user, GlobalAccountDataEventType::PushRules.to_string().into())?
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "PushRules event not found."))?;
|
||||
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(event.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))?;
|
||||
|
||||
if let Err(error) = account_data
|
||||
.content
|
||||
.global
|
||||
.remove(body.kind.clone(), &body.rule_id)
|
||||
{
|
||||
if let Err(error) = account_data.content.global.remove(body.kind.clone(), &body.rule_id) {
|
||||
let err = match error {
|
||||
RemovePushRuleError::ServerDefault => Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Cannot delete a server-default pushrule.",
|
||||
),
|
||||
RemovePushRuleError::NotFound => {
|
||||
Error::BadRequest(ErrorKind::NotFound, "Push rule not found.")
|
||||
}
|
||||
RemovePushRuleError::ServerDefault => {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Cannot delete a server-default pushrule.")
|
||||
},
|
||||
RemovePushRuleError::NotFound => Error::BadRequest(ErrorKind::NotFound, "Push rule not found."),
|
||||
_ => Error::BadRequest(ErrorKind::InvalidParam, "Invalid data."),
|
||||
};
|
||||
|
||||
|
@ -404,9 +309,7 @@ pub async fn delete_pushrule_route(
|
|||
/// # `GET /_matrix/client/r0/pushers`
|
||||
///
|
||||
/// Gets all currently active pushers for the sender user.
|
||||
pub async fn get_pushers_route(
|
||||
body: Ruma<get_pushers::v3::Request>,
|
||||
) -> Result<get_pushers::v3::Response> {
|
||||
pub async fn get_pushers_route(body: Ruma<get_pushers::v3::Request>) -> Result<get_pushers::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
Ok(get_pushers::v3::Response {
|
||||
|
@ -419,14 +322,10 @@ pub async fn get_pushers_route(
|
|||
/// Adds a pusher for the sender user.
|
||||
///
|
||||
/// - TODO: Handle `append`
|
||||
pub async fn set_pushers_route(
|
||||
body: Ruma<set_pusher::v3::Request>,
|
||||
) -> Result<set_pusher::v3::Response> {
|
||||
pub async fn set_pushers_route(body: Ruma<set_pusher::v3::Request>) -> Result<set_pusher::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
services()
|
||||
.pusher
|
||||
.set_pusher(sender_user, body.action.clone())?;
|
||||
services().pusher.set_pusher(sender_user, body.action.clone())?;
|
||||
|
||||
Ok(set_pusher::v3::Response::default())
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt},
|
||||
events::{
|
||||
|
@ -7,17 +8,17 @@ use ruma::{
|
|||
},
|
||||
MilliSecondsSinceUnixEpoch,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{service::rooms::timeline::PduCount, services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers`
|
||||
///
|
||||
/// Sets different types of read markers.
|
||||
///
|
||||
/// - Updates fully-read account data event to `fully_read`
|
||||
/// - If `read_receipt` is set: Update private marker and public read receipt EDU
|
||||
pub async fn set_read_marker_route(
|
||||
body: Ruma<set_read_marker::v3::Request>,
|
||||
) -> Result<set_read_marker::v3::Response> {
|
||||
/// - If `read_receipt` is set: Update private marker and public read receipt
|
||||
/// EDU
|
||||
pub async fn set_read_marker_route(body: Ruma<set_read_marker::v3::Request>) -> Result<set_read_marker::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if let Some(fully_read) = &body.fully_read {
|
||||
|
@ -35,10 +36,7 @@ pub async fn set_read_marker_route(
|
|||
}
|
||||
|
||||
if body.private_read_receipt.is_some() || body.read_receipt.is_some() {
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
if let Some(event) = &body.private_read_receipt {
|
||||
|
@ -46,24 +44,17 @@ pub async fn set_read_marker_route(
|
|||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(event)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
}
|
||||
},
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services()
|
||||
.rooms
|
||||
.edus
|
||||
.read_receipt
|
||||
.private_read_set(&body.room_id, sender_user, count)?;
|
||||
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||
}
|
||||
|
||||
if let Some(event) = &body.read_receipt {
|
||||
|
@ -98,19 +89,14 @@ pub async fn set_read_marker_route(
|
|||
/// # `POST /_matrix/client/r0/rooms/{roomId}/receipt/{receiptType}/{eventId}`
|
||||
///
|
||||
/// Sets private read marker and public read receipt EDU.
|
||||
pub async fn create_receipt_route(
|
||||
body: Ruma<create_receipt::v3::Request>,
|
||||
) -> Result<create_receipt::v3::Response> {
|
||||
pub async fn create_receipt_route(body: Ruma<create_receipt::v3::Request>) -> Result<create_receipt::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if matches!(
|
||||
&body.receipt_type,
|
||||
create_receipt::v3::ReceiptType::Read | create_receipt::v3::ReceiptType::ReadPrivate
|
||||
) {
|
||||
services()
|
||||
.rooms
|
||||
.user
|
||||
.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
services().rooms.user.reset_notification_counts(sender_user, &body.room_id)?;
|
||||
}
|
||||
|
||||
match body.receipt_type {
|
||||
|
@ -126,7 +112,7 @@ pub async fn create_receipt_route(
|
|||
RoomAccountDataEventType::FullyRead,
|
||||
&serde_json::to_value(fully_read_event).expect("to json value always works"),
|
||||
)?;
|
||||
}
|
||||
},
|
||||
create_receipt::v3::ReceiptType::Read => {
|
||||
let mut user_receipts = BTreeMap::new();
|
||||
user_receipts.insert(
|
||||
|
@ -150,31 +136,24 @@ pub async fn create_receipt_route(
|
|||
room_id: body.room_id.clone(),
|
||||
},
|
||||
)?;
|
||||
}
|
||||
},
|
||||
create_receipt::v3::ReceiptType::ReadPrivate => {
|
||||
let count = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_count(&body.event_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Event does not exist.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?;
|
||||
let count = match count {
|
||||
PduCount::Backfilled(_) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Read receipt is in backfilled timeline",
|
||||
))
|
||||
}
|
||||
},
|
||||
PduCount::Normal(c) => c,
|
||||
};
|
||||
services().rooms.edus.read_receipt.private_read_set(
|
||||
&body.room_id,
|
||||
sender_user,
|
||||
count,
|
||||
)?;
|
||||
}
|
||||
services().rooms.edus.read_receipt.private_read_set(&body.room_id, sender_user, count)?;
|
||||
},
|
||||
_ => return Err(Error::bad_database("Unsupported receipt type")),
|
||||
}
|
||||
|
||||
|
|
|
@ -1,33 +1,24 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::redact::redact_event,
|
||||
events::{room::redaction::RoomRedactionEventContent, TimelineEventType},
|
||||
};
|
||||
|
||||
use serde_json::value::to_raw_value;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/redact/{eventId}/{txnId}`
|
||||
///
|
||||
/// Tries to send a redaction event into the room.
|
||||
///
|
||||
/// - TODO: Handle txn id
|
||||
pub async fn redact_event_route(
|
||||
body: Ruma<redact_event::v3::Request>,
|
||||
) -> Result<redact_event::v3::Response> {
|
||||
pub async fn redact_event_route(body: Ruma<redact_event::v3::Request>) -> Result<redact_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let body = body.body;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(body.room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = services()
|
||||
|
@ -54,5 +45,7 @@ pub async fn redact_event_route(
|
|||
drop(state_lock);
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(redact_event::v3::Response { event_id })
|
||||
Ok(redact_event::v3::Response {
|
||||
event_id,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,6 +1,5 @@
|
|||
use ruma::api::client::relations::{
|
||||
get_relating_events, get_relating_events_with_rel_type,
|
||||
get_relating_events_with_rel_type_and_event_type,
|
||||
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
|
||||
};
|
||||
|
||||
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma};
|
||||
|
@ -20,22 +19,12 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
|
|||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
let res = services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
|
@ -46,13 +35,11 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
|
|||
limit,
|
||||
)?;
|
||||
|
||||
Ok(
|
||||
get_relating_events_with_rel_type_and_event_type::v1::Response {
|
||||
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
|
||||
chunk: res.chunk,
|
||||
next_batch: res.next_batch,
|
||||
prev_batch: res.prev_batch,
|
||||
},
|
||||
)
|
||||
})
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}`
|
||||
|
@ -70,22 +57,12 @@ pub async fn get_relating_events_with_rel_type_route(
|
|||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
let res = services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
let res = services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
|
@ -118,22 +95,12 @@ pub async fn get_relating_events_route(
|
|||
},
|
||||
};
|
||||
|
||||
let to = body
|
||||
.to
|
||||
.as_ref()
|
||||
.and_then(|t| PduCount::try_from_string(t).ok());
|
||||
let to = body.to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
let limit = body
|
||||
.limit
|
||||
.and_then(|u| u32::try_from(u).ok())
|
||||
.map_or(10_usize, |u| u as usize)
|
||||
.min(100);
|
||||
let limit = body.limit.and_then(|u| u32::try_from(u).ok()).map_or(10_usize, |u| u as usize).min(100);
|
||||
|
||||
services()
|
||||
.rooms
|
||||
.pdu_metadata
|
||||
.paginate_relations_with_filter(
|
||||
services().rooms.pdu_metadata.paginate_relations_with_filter(
|
||||
sender_user,
|
||||
&body.room_id,
|
||||
&body.event_id,
|
||||
|
|
|
@@ -1,6 +1,5 @@
use std::time::Duration;

use crate::{services, utils::HtmlEscape, Error, Result, Ruma};
use rand::Rng;
use ruma::{
api::client::{error::ErrorKind, room::report_content},
@@ -10,13 +9,12 @@ use ruma::{
use tokio::time::sleep;
use tracing::{debug, info};

use crate::{services, utils::HtmlEscape, Error, Result, Ruma};

/// # `POST /_matrix/client/v3/rooms/{roomId}/report/{eventId}`
///
/// Reports an inappropriate event to homeserver admins
///
pub async fn report_event_route(
body: Ruma<report_content::v3::Request>,
) -> Result<report_content::v3::Response> {
pub async fn report_event_route(body: Ruma<report_content::v3::Request>) -> Result<report_content::v3::Response> {
// user authentication
let sender_user = body.sender_user.as_ref().expect("user is authenticated");

@@ -30,7 +28,7 @@ pub async fn report_event_route(
ErrorKind::NotFound,
"Event ID is not known to us or Event ID is invalid",
))
}
},
};

// check if the room ID from the URI matches the PDU's room ID
@@ -71,17 +69,12 @@ pub async fn report_event_route(
));
};

// send admin room message that we received the report with an @room ping for urgency
services()
.admin
.send_message(message::RoomMessageEventContent::text_html(
// send admin room message that we received the report with an @room ping for
// urgency
services().admin.send_message(message::RoomMessageEventContent::text_html(
format!(
"@room Report received from: {}\n\n\
Event ID: {}\n\
Room ID: {}\n\
Sent By: {}\n\n\
Report Score: {}\n\
Report Reason: {}",
"@room Report received from: {}\n\nEvent ID: {}\nRoom ID: {}\nSent By: {}\n\nReport Score: {}\nReport \
Reason: {}",
sender_user.to_owned(),
pdu.event_id,
pdu.room_id,
@@ -105,8 +98,9 @@ pub async fn report_event_route(
),
));

// even though this is kinda security by obscurity, let's still make a small random delay sending a successful response
// per spec suggestion regarding enumerating for potential events existing in our server.
// even though this is kinda security by obscurity, let's still make a small
// random delay sending a successful response per spec suggestion regarding
// enumerating for potential events existing in our server.
let time_to_wait = rand::thread_rng().gen_range(8..21);
debug!(
"Got successful /report request, waiting {} seconds before sending successful response.",
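// The hunk above is cut off after computing `time_to_wait`; given the `sleep` and
// `Duration` imports in this file, the delay presumably completes along these lines
// (illustrative sketch only, not a quote of the remaining diff):
async fn delay_report_response() {
	use std::time::Duration;

	use rand::Rng;
	use tokio::time::sleep;

	// wait a random 8-20 seconds before acknowledging the report, per the comment above
	let time_to_wait = rand::thread_rng().gen_range(8..21);
	sleep(Duration::from_secs(time_to_wait)).await;
}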
@ -1,6 +1,5 @@
|
|||
use crate::{
|
||||
api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma,
|
||||
};
|
||||
use std::{cmp::max, collections::BTreeMap, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
|
@ -23,13 +22,13 @@ use ruma::{
|
|||
},
|
||||
int,
|
||||
serde::JsonObject,
|
||||
CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId,
|
||||
RoomVersionId,
|
||||
CanonicalJsonObject, CanonicalJsonValue, OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId, RoomVersionId,
|
||||
};
|
||||
use serde_json::{json, value::to_raw_value};
|
||||
use std::{cmp::max, collections::BTreeMap, sync::Arc};
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{api::client_server::invite_helper, service::pdu::PduBuilder, services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/v3/createRoom`
|
||||
///
|
||||
/// Creates a new room.
|
||||
|
@ -46,27 +45,19 @@ use tracing::{debug, error, info, warn};
|
|||
/// - Send events listed in initial state
|
||||
/// - Send events implied by `name` and `topic`
|
||||
/// - Send invite events
|
||||
pub async fn create_room_route(
|
||||
body: Ruma<create_room::v3::Request>,
|
||||
) -> Result<create_room::v3::Response> {
|
||||
pub async fn create_room_route(body: Ruma<create_room::v3::Request>) -> Result<create_room::v3::Response> {
|
||||
use create_room::v3::RoomPreset;
|
||||
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services().globals.allow_room_creation()
|
||||
&& !&body.from_appservice
|
||||
&& !services().users.is_admin(sender_user)?
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Room creation has been disabled.",
|
||||
));
|
||||
if !services().globals.allow_room_creation() && !&body.from_appservice && !services().users.is_admin(sender_user)? {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Room creation has been disabled."));
|
||||
}
|
||||
|
||||
let room_id: OwnedRoomId;
|
||||
|
||||
// checks if the user specified an explicit (custom) room_id to be created with in request body.
|
||||
// falls back to normal generated room ID if not specified.
|
||||
// checks if the user specified an explicit (custom) room_id to be created with
|
||||
// in request body. falls back to normal generated room ID if not specified.
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &body.json_body {
|
||||
match json_body.get("room_id") {
|
||||
Some(custom_room_id) => {
|
||||
|
@ -76,7 +67,8 @@ pub async fn create_room_route(
|
|||
if custom_room_id_s.contains(':') {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Custom room ID contained `:` which is not allowed. Please note that this expects a localpart, not the full room ID.",
|
||||
"Custom room ID contained `:` which is not allowed. Please note that this expects a \
|
||||
localpart, not the full room ID.",
|
||||
));
|
||||
} else if custom_room_id_s.contains(char::is_whitespace) {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -84,41 +76,24 @@ pub async fn create_room_route(
|
|||
"Custom room ID contained spaces which is not valid.",
|
||||
));
|
||||
} else if custom_room_id_s.len() > 255 {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Custom room ID is too long.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID is too long."));
|
||||
}
|
||||
|
||||
// apply forbidden room alias checks to custom room IDs too
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_room_names()
|
||||
.is_match(&custom_room_id_s)
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Custom room ID is forbidden.",
|
||||
));
|
||||
if services().globals.forbidden_room_names().is_match(&custom_room_id_s) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Custom room ID is forbidden."));
|
||||
}
|
||||
|
||||
let full_room_id = "!".to_owned()
|
||||
+ &custom_room_id_s.replace('"', "")
|
||||
+ ":"
|
||||
+ services().globals.server_name().as_ref();
|
||||
+ ":" + services().globals.server_name().as_ref();
|
||||
debug!("Full room ID: {}", full_room_id);
|
||||
|
||||
room_id = RoomId::parse(full_room_id).map_err(|e| {
|
||||
info!(
|
||||
"User attempted to create room with custom room ID but failed parsing: {}",
|
||||
e
|
||||
);
|
||||
Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Custom room ID could not be parsed",
|
||||
)
|
||||
info!("User attempted to create room with custom room ID but failed parsing: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Custom room ID could not be parsed")
|
||||
})?;
|
||||
}
|
||||
},
|
||||
None => room_id = RoomId::new(services().globals.server_name()),
|
||||
}
|
||||
} else {
|
||||
|
@ -135,27 +110,17 @@ pub async fn create_room_route(
|
|||
|
||||
services().rooms.short.get_or_create_shortroomid(&room_id)?;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let alias: Option<OwnedRoomAliasId> =
|
||||
body.room_alias_name
|
||||
.as_ref()
|
||||
.map_or(Ok(None), |localpart| {
|
||||
|
||||
let alias: Option<OwnedRoomAliasId> = body.room_alias_name.as_ref().map_or(Ok(None), |localpart| {
|
||||
// Basic checks on the room alias validity
|
||||
if localpart.contains(':') {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the full room alias.",
|
||||
"Room alias contained `:` which is not allowed. Please note that this expects a localpart, not the \
|
||||
full room alias.",
|
||||
));
|
||||
} else if localpart.contains(char::is_whitespace) {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -164,9 +129,9 @@ pub async fn create_room_route(
|
|||
));
|
||||
} else if localpart.len() > 255 {
|
||||
// there is nothing spec-wise saying to check the limit of this,
|
||||
// however absurdly long room aliases are guaranteed to be unreadable or done maliciously.
|
||||
// there is no reason a room alias should even exceed 100 characters as is.
|
||||
// generally in spec, 255 is matrix's fav number
|
||||
// however absurdly long room aliases are guaranteed to be unreadable or done
|
||||
// maliciously. there is no reason a room alias should even exceed 100
|
||||
// characters as is. generally in spec, 255 is matrix's fav number
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Room alias is excessively long, clients may not be able to handle this. Please shorten it.",
|
||||
|
@ -179,37 +144,18 @@ pub async fn create_room_route(
|
|||
}
|
||||
|
||||
// check if room alias is forbidden
|
||||
if services()
|
||||
.globals
|
||||
.forbidden_room_names()
|
||||
.is_match(localpart)
|
||||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Room alias name is forbidden.",
|
||||
));
|
||||
if services().globals.forbidden_room_names().is_match(localpart) {
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Room alias name is forbidden."));
|
||||
}
|
||||
|
||||
let alias = RoomAliasId::parse(format!(
|
||||
"#{}:{}",
|
||||
localpart,
|
||||
services().globals.server_name()
|
||||
))
|
||||
.map_err(|e| {
|
||||
let alias =
|
||||
RoomAliasId::parse(format!("#{}:{}", localpart, services().globals.server_name())).map_err(|e| {
|
||||
warn!("Failed to parse room alias for room ID {}: {e}", room_id);
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid room alias specified.")
|
||||
})?;
|
||||
|
||||
if services()
|
||||
.rooms
|
||||
.alias
|
||||
.resolve_local_alias(&alias)?
|
||||
.is_some()
|
||||
{
|
||||
Err(Error::BadRequest(
|
||||
ErrorKind::RoomInUse,
|
||||
"Room alias already exists.",
|
||||
))
|
||||
if services().rooms.alias.resolve_local_alias(&alias)?.is_some() {
|
||||
Err(Error::BadRequest(ErrorKind::RoomInUse, "Room alias already exists."))
|
||||
} else {
|
||||
Ok(Some(alias))
|
||||
}
|
||||
|
@ -217,11 +163,7 @@ pub async fn create_room_route(
|
|||
|
||||
let room_version = match body.room_version.clone() {
|
||||
Some(room_version) => {
|
||||
if services()
|
||||
.globals
|
||||
.supported_room_versions()
|
||||
.contains(&room_version)
|
||||
{
|
||||
if services().globals.supported_room_versions().contains(&room_version) {
|
||||
room_version
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -229,15 +171,13 @@ pub async fn create_room_route(
|
|||
"This server does not support that room version.",
|
||||
));
|
||||
}
|
||||
}
|
||||
},
|
||||
None => services().globals.default_room_version(),
|
||||
};
|
||||
|
||||
let content = match &body.creation_content {
|
||||
Some(content) => {
|
||||
let mut content = content
|
||||
.deserialize_as::<CanonicalJsonObject>()
|
||||
.map_err(|e| {
|
||||
let mut content = content.deserialize_as::<CanonicalJsonObject>().map_err(|e| {
|
||||
error!("Failed to deserialise content as canonical JSON: {}", e);
|
||||
Error::bad_database("Failed to deserialise content as canonical JSON.")
|
||||
})?;
|
||||
|
@ -259,25 +199,25 @@ pub async fn create_room_route(
|
|||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
||||
})?,
|
||||
);
|
||||
}
|
||||
RoomVersionId::V11 => {} // V11 removed the "creator" key
|
||||
},
|
||||
RoomVersionId::V11 => {}, // V11 removed the "creator" key
|
||||
_ => {
|
||||
warn!("Unexpected or unsupported room version {}", room_version);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"Unexpected or unsupported room version found",
|
||||
));
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
content.insert(
|
||||
"room_version".into(),
|
||||
json!(room_version.as_str()).try_into().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
||||
})?,
|
||||
json!(room_version.as_str())
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?,
|
||||
);
|
||||
content
|
||||
}
|
||||
},
|
||||
None => {
|
||||
// TODO: Add correct value for v11
|
||||
let content = match room_version {
|
||||
|
@ -298,7 +238,7 @@ pub async fn create_room_route(
|
|||
ErrorKind::BadJson,
|
||||
"Unexpected or unsupported room version found",
|
||||
));
|
||||
}
|
||||
},
|
||||
};
|
||||
let mut content = serde_json::from_str::<CanonicalJsonObject>(
|
||||
to_raw_value(&content)
|
||||
|
@ -308,26 +248,20 @@ pub async fn create_room_route(
|
|||
.unwrap();
|
||||
content.insert(
|
||||
"room_version".into(),
|
||||
json!(room_version.as_str()).try_into().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::BadJson, "Invalid creation content")
|
||||
})?,
|
||||
json!(room_version.as_str())
|
||||
.try_into()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"))?,
|
||||
);
|
||||
content
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
// Validate creation content
|
||||
let de_result = serde_json::from_str::<CanonicalJsonObject>(
|
||||
to_raw_value(&content)
|
||||
.expect("Invalid creation content")
|
||||
.get(),
|
||||
);
|
||||
let de_result =
|
||||
serde_json::from_str::<CanonicalJsonObject>(to_raw_value(&content).expect("Invalid creation content").get());
|
||||
|
||||
if de_result.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"Invalid creation content",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::BadJson, "Invalid creation content"));
|
||||
}
|
||||
|
||||
// 1. The room create event
|
||||
|
@ -401,9 +335,7 @@ pub async fn create_room_route(
|
|||
|
||||
if let Some(power_level_content_override) = &body.power_level_content_override {
|
||||
let json: JsonObject = serde_json::from_str(power_level_content_override.json().get())
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override.")
|
||||
})?;
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::BadJson, "Invalid power_level_content_override."))?;
|
||||
|
||||
for (key, value) in json {
|
||||
power_levels_content[key] = value;
|
||||
|
@ -416,8 +348,7 @@ pub async fn create_room_route(
|
|||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomPowerLevels,
|
||||
content: to_raw_value(&power_levels_content)
|
||||
.expect("to_raw_value always works on serde_json::Value"),
|
||||
content: to_raw_value(&power_levels_content).expect("to_raw_value always works on serde_json::Value"),
|
||||
unsigned: None,
|
||||
state_key: Some("".to_owned()),
|
||||
redacts: None,
|
||||
|
@ -484,9 +415,7 @@ pub async fn create_room_route(
|
|||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomHistoryVisibility,
|
||||
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(
|
||||
HistoryVisibility::Shared,
|
||||
))
|
||||
content: to_raw_value(&RoomHistoryVisibilityEventContent::new(HistoryVisibility::Shared))
|
||||
.expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some("".to_owned()),
|
||||
|
@ -531,17 +460,11 @@ pub async fn create_room_route(
|
|||
pdu_builder.state_key.get_or_insert_with(|| "".to_owned());
|
||||
|
||||
// Silently skip encryption events if they are not allowed
|
||||
if pdu_builder.event_type == TimelineEventType::RoomEncryption
|
||||
&& !services().globals.allow_encryption()
|
||||
{
|
||||
if pdu_builder.event_type == TimelineEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||
continue;
|
||||
}
|
||||
|
||||
services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock)
|
||||
.await?;
|
||||
services().rooms.timeline.build_and_append_pdu(pdu_builder, sender_user, &room_id, &state_lock).await?;
|
||||
}
|
||||
|
||||
// 7. Events implied by name and topic
|
||||
|
@ -611,26 +534,17 @@ pub async fn create_room_route(
|
|||
///
|
||||
/// Gets a single event.
|
||||
///
|
||||
/// - You have to currently be joined to the room (TODO: Respect history visibility)
|
||||
pub async fn get_room_event_route(
|
||||
body: Ruma<get_room_event::v3::Request>,
|
||||
) -> Result<get_room_event::v3::Response> {
|
||||
/// - You have to currently be joined to the room (TODO: Respect history
|
||||
/// visibility)
|
||||
pub async fn get_room_event_route(body: Ruma<get_room_event::v3::Request>) -> Result<get_room_event::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(&body.event_id)?
|
||||
.ok_or_else(|| {
|
||||
let event = services().rooms.timeline.get_pdu(&body.event_id)?.ok_or_else(|| {
|
||||
warn!("Event not found, event ID: {:?}", &body.event_id);
|
||||
Error::BadRequest(ErrorKind::NotFound, "Event not found.")
|
||||
})?;
|
||||
|
||||
if !services().rooms.state_accessor.user_can_see_event(
|
||||
sender_user,
|
||||
&event.room_id,
|
||||
&body.event_id,
|
||||
)? {
|
||||
if !services().rooms.state_accessor.user_can_see_event(sender_user, &event.room_id, &body.event_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this event.",
|
||||
|
@ -649,17 +563,12 @@ pub async fn get_room_event_route(
|
|||
///
|
||||
/// Lists all aliases of the room.
|
||||
///
|
||||
/// - Only users joined to the room are allowed to call this TODO: Allow any user to call it if history_visibility is world readable
|
||||
pub async fn get_room_aliases_route(
|
||||
body: Ruma<aliases::v3::Request>,
|
||||
) -> Result<aliases::v3::Response> {
|
||||
/// - Only users joined to the room are allowed to call this TODO: Allow any
|
||||
/// user to call it if history_visibility is world readable
|
||||
pub async fn get_room_aliases_route(body: Ruma<aliases::v3::Request>) -> Result<aliases::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &body.room_id)?
|
||||
{
|
||||
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
|
@ -686,16 +595,10 @@ pub async fn get_room_aliases_route(
|
|||
/// - Transfers some state events
|
||||
/// - Moves local aliases
|
||||
/// - Modifies old room power levels to prevent users from speaking
|
||||
pub async fn upgrade_room_route(
|
||||
body: Ruma<upgrade_room::v3::Request>,
|
||||
) -> Result<upgrade_room::v3::Response> {
|
||||
pub async fn upgrade_room_route(body: Ruma<upgrade_room::v3::Request>) -> Result<upgrade_room::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.globals
|
||||
.supported_room_versions()
|
||||
.contains(&body.new_version)
|
||||
{
|
||||
if !services().globals.supported_room_versions().contains(&body.new_version) {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnsupportedRoomVersion,
|
||||
"This server does not support that room version.",
|
||||
|
@ -704,24 +607,15 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Create a replacement room
|
||||
let replacement_room = RoomId::new(services().globals.server_name());
|
||||
services()
|
||||
.rooms
|
||||
.short
|
||||
.get_or_create_shortroomid(&replacement_room)?;
|
||||
services().rooms.short.get_or_create_shortroomid(&replacement_room)?;
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(body.room_id.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(body.room_id.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
// Send a m.room.tombstone event to the old room to indicate that it is not intended to be used any further
|
||||
// Fail if the sender does not have the required permissions
|
||||
// Send a m.room.tombstone event to the old room to indicate that it is not
|
||||
// intended to be used any further Fail if the sender does not have the required
|
||||
// permissions
|
||||
let tombstone_event_id = services()
|
||||
.rooms
|
||||
.timeline
|
||||
|
@ -745,15 +639,8 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Change lock to replacement room
|
||||
drop(state_lock);
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(replacement_room.clone())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(replacement_room.clone()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
// Get the old room creation event
|
||||
|
@ -774,7 +661,8 @@ pub async fn upgrade_room_route(
|
|||
(*tombstone_event_id).to_owned(),
|
||||
));
|
||||
|
||||
// Send a m.room.create event containing a predecessor field and the applicable room_version
|
||||
// Send a m.room.create event containing a predecessor field and the applicable
|
||||
// room_version
|
||||
match body.new_version {
|
||||
RoomVersionId::V1
|
||||
| RoomVersionId::V2
|
||||
|
@ -793,21 +681,18 @@ pub async fn upgrade_room_route(
|
|||
Error::BadRequest(ErrorKind::BadJson, "Error forming creation event")
|
||||
})?,
|
||||
);
|
||||
}
|
||||
},
|
||||
RoomVersionId::V11 => {
|
||||
// "creator" key no longer exists in V11 rooms
|
||||
create_event_content.remove("creator");
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
warn!(
|
||||
"Unexpected or unsupported room version {}",
|
||||
body.new_version
|
||||
);
|
||||
warn!("Unexpected or unsupported room version {}", body.new_version);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"Unexpected or unsupported room version found",
|
||||
));
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
create_event_content.insert(
|
||||
|
@ -825,16 +710,11 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Validate creation event content
|
||||
let de_result = serde_json::from_str::<CanonicalJsonObject>(
|
||||
to_raw_value(&create_event_content)
|
||||
.expect("Error forming creation event")
|
||||
.get(),
|
||||
to_raw_value(&create_event_content).expect("Error forming creation event").get(),
|
||||
);
|
||||
|
||||
if de_result.is_err() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::BadJson,
|
||||
"Error forming creation event",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::BadJson, "Error forming creation event"));
|
||||
}
|
||||
|
||||
services()
|
||||
|
@ -843,8 +723,7 @@ pub async fn upgrade_room_route(
|
|||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomCreate,
|
||||
content: to_raw_value(&create_event_content)
|
||||
.expect("event is valid, we just created it"),
|
||||
content: to_raw_value(&create_event_content).expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some("".to_owned()),
|
||||
redacts: None,
|
||||
|
@ -898,12 +777,7 @@ pub async fn upgrade_room_route(
|
|||
|
||||
// Replicate transferable state events to the new room
|
||||
for event_type in transferable_state_events {
|
||||
let event_content =
|
||||
match services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &event_type, "")?
|
||||
{
|
||||
let event_content = match services().rooms.state_accessor.room_state_get(&body.room_id, &event_type, "")? {
|
||||
Some(v) => v.content.clone(),
|
||||
None => continue, // Skipping missing events.
|
||||
};
|
||||
|
@ -927,16 +801,8 @@ pub async fn upgrade_room_route(
|
|||
}
|
||||

// Moves any local aliases to the new room
for alias in services()
.rooms
.alias
.local_aliases_for_room(&body.room_id)
.filter_map(std::result::Result::ok)
{
services()
.rooms
.alias
.set_alias(&alias, &replacement_room)?;
for alias in services().rooms.alias.local_aliases_for_room(&body.room_id).filter_map(std::result::Result::ok) {
services().rooms.alias.set_alias(&alias, &replacement_room)?;
}

// Get the old room power levels
|
||||
|
@ -956,15 +822,15 @@ pub async fn upgrade_room_route(
|
|||
power_levels_event_content.events_default = new_level;
|
||||
power_levels_event_content.invite = new_level;
|
||||
|
||||
// Modify the power levels in the old room to prevent sending of events and inviting new users
|
||||
// Modify the power levels in the old room to prevent sending of events and
|
||||
// inviting new users
|
||||
let _ = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.build_and_append_pdu(
|
||||
PduBuilder {
|
||||
event_type: TimelineEventType::RoomPowerLevels,
|
||||
content: to_raw_value(&power_levels_event_content)
|
||||
.expect("event is valid, we just created it"),
|
||||
content: to_raw_value(&power_levels_event_content).expect("event is valid, we just created it"),
|
||||
unsigned: None,
|
||||
state_key: Some("".to_owned()),
|
||||
redacts: None,
|
||||
|
@ -978,5 +844,7 @@ pub async fn upgrade_room_route(
|
|||
drop(state_lock);
|
||||
|
||||
// Return the replacement room id
|
||||
Ok(upgrade_room::v3::Response { replacement_room })
|
||||
Ok(upgrade_room::v3::Response {
|
||||
replacement_room,
|
||||
})
|
||||
}
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::api::client::{
|
||||
error::ErrorKind,
|
||||
search::search_events::{
|
||||
|
@ -7,28 +8,22 @@ use ruma::api::client::{
|
|||
},
|
||||
};
|
||||
|
||||
use std::collections::BTreeMap;
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/search`
|
||||
///
|
||||
/// Searches rooms for messages.
|
||||
///
|
||||
/// - Only works if the user is currently joined to the room (TODO: Respect history visibility)
|
||||
pub async fn search_events_route(
|
||||
body: Ruma<search_events::v3::Request>,
|
||||
) -> Result<search_events::v3::Response> {
|
||||
/// - Only works if the user is currently joined to the room (TODO: Respect
|
||||
/// history visibility)
|
||||
pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Result<search_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
|
||||
let filter = &search_criteria.filter;
|
||||
|
||||
let room_ids = filter.rooms.clone().unwrap_or_else(|| {
|
||||
services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(sender_user)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.collect()
|
||||
services().rooms.state_cache.rooms_joined(sender_user).filter_map(std::result::Result::ok).collect()
|
||||
});
|
||||
|
||||
// Use limit or else 10, with maximum 100
|
||||
|
@ -37,34 +32,21 @@ pub async fn search_events_route(
|
|||
let mut searches = Vec::new();
|
||||
|
||||
for room_id in room_ids {
|
||||
if !services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.is_joined(sender_user, &room_id)?
|
||||
{
|
||||
if !services().rooms.state_cache.is_joined(sender_user, &room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view this room.",
|
||||
));
|
||||
}
|
||||
|
||||
if let Some(search) = services()
|
||||
.rooms
|
||||
.search
|
||||
.search_pdus(&room_id, &search_criteria.search_term)?
|
||||
{
|
||||
if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? {
|
||||
searches.push(search.0.peekable());
|
||||
}
|
||||
}
|
||||
|
||||
let skip = match body.next_batch.as_ref().map(|s| s.parse()) {
|
||||
Some(Ok(s)) => s,
|
||||
Some(Err(_)) => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Invalid next_batch token.",
|
||||
))
|
||||
}
|
||||
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
|
||||
None => 0, // Default to the start
|
||||
};
|
||||
|
||||
|
|
|
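The search hunk above resolves the optional `next_batch` token in three cases: absent (start at 0), numeric (resume at that offset), and present but unparsable (client error). A minimal standalone sketch of that handling, with a hypothetical helper name and a plain string error in place of the server's Error type:

fn parse_next_batch(next_batch: Option<&str>) -> Result<u64, &'static str> {
	match next_batch.map(|s| s.parse::<u64>()) {
		// A valid numeric token: resume from that offset.
		Some(Ok(skip)) => Ok(skip),
		// A token was supplied but does not parse: reject the request.
		Some(Err(_)) => Err("Invalid next_batch token."),
		// No token: default to the start.
		None => Ok(0),
	}
}
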
@ -1,5 +1,3 @@
|
|||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
use argon2::{PasswordHash, PasswordVerifier};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
|
@ -22,6 +20,9 @@ use ruma::{
|
|||
use serde::Deserialize;
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use super::{DEVICE_ID_LENGTH, TOKEN_LENGTH};
|
||||
use crate::{services, utils, Error, Result, Ruma};
|
||||
|
||||
#[derive(Debug, Deserialize)]
|
||||
struct Claims {
|
||||
sub: String,
|
||||
|
@ -30,11 +31,9 @@ struct Claims {
|
|||
|
||||
/// # `GET /_matrix/client/v3/login`
|
||||
///
|
||||
/// Get the supported login types of this server. One of these should be used as the `type` field
|
||||
/// when logging in.
|
||||
pub async fn get_login_types_route(
|
||||
_body: Ruma<get_login_types::v3::Request>,
|
||||
) -> Result<get_login_types::v3::Response> {
|
||||
/// Get the supported login types of this server. One of these should be used as
|
||||
/// the `type` field when logging in.
|
||||
pub async fn get_login_types_route(_body: Ruma<get_login_types::v3::Request>) -> Result<get_login_types::v3::Response> {
|
||||
Ok(get_login_types::v3::Response::new(vec to see
|
||||
/// Note: You can use [`GET
|
||||
/// /_matrix/client/r0/login`](fn.get_supported_versions_route.html) to see
|
||||
/// supported login types.
|
||||
pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Response> {
|
||||
// Validate login method
|
||||
|
@ -68,16 +70,19 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
debug!("Using username from identifier field");
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!("User \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
||||
warn!(
|
||||
"User \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||
destined to be removed in a future Matrix release.",
|
||||
user_id
|
||||
);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
warn!("Bad login type: {:?}", &body.login_info);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
|
||||
let user_id =
|
||||
UserId::parse_with_server_name(username, services().globals.server_name())
|
||||
.map_err(|e| {
|
||||
let user_id = UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?;
|
||||
|
@ -85,16 +90,10 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
let hash = services()
|
||||
.users
|
||||
.password_hash(&user_id)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
))?;
|
||||
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."))?;
|
||||
|
||||
if hash.is_empty() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UserDeactivated,
|
||||
"The user has been deactivated",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::UserDeactivated, "The user has been deactivated"));
|
||||
}
|
||||
|
||||
let Ok(parsed_hash) = PasswordHash::new(&hash) else {
|
||||
|
@ -102,29 +101,21 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
return Err(Error::BadServerResponse("could not hash"));
|
||||
};
|
||||
|
||||
let hash_matches = services()
|
||||
.globals
|
||||
.argon
|
||||
.verify_password(password.as_bytes(), &parsed_hash)
|
||||
.is_ok();
|
||||
let hash_matches = services().globals.argon.verify_password(password.as_bytes(), &parsed_hash).is_ok();
|
||||
|
||||
if !hash_matches {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Wrong username or password.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Wrong username or password."));
|
||||
}
|
||||
|
||||
user_id
|
||||
}
|
||||
login::v3::LoginInfo::Token(login::v3::Token { token }) => {
|
||||
},
|
||||
login::v3::LoginInfo::Token(login::v3::Token {
|
||||
token,
|
||||
}) => {
|
||||
debug!("Got token login type");
|
||||
if let Some(jwt_decoding_key) = services().globals.jwt_decoding_key() {
|
||||
let token = jsonwebtoken::decode::<Claims>(
|
||||
token,
|
||||
jwt_decoding_key,
|
||||
&jsonwebtoken::Validation::default(),
|
||||
)
|
||||
let token =
|
||||
jsonwebtoken::decode::<Claims>(token, jwt_decoding_key, &jsonwebtoken::Validation::default())
|
||||
.map_err(|e| {
|
||||
warn!("Failed to parse JWT token from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Token is invalid.")
|
||||
|
@ -132,19 +123,17 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
|
||||
let username = token.claims.sub.to_lowercase();
|
||||
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
||||
|e| {
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from user logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
},
|
||||
)?
|
||||
})?
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Token login is not supported (server has no jwt decoding key).",
|
||||
));
|
||||
}
|
||||
}
|
||||
},
|
||||
#[allow(deprecated)]
|
||||
login::v3::LoginInfo::ApplicationService(login::v3::ApplicationService {
|
||||
identifier,
|
||||
|
@ -152,79 +141,65 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
}) => {
|
||||
debug!("Got appservice login type");
|
||||
if !body.from_appservice {
|
||||
info!("User tried logging in as an appservice, but request body is not from a known/registered appservice");
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Forbidden login type.",
|
||||
));
|
||||
info!(
|
||||
"User tried logging in as an appservice, but request body is not from a known/registered \
|
||||
appservice"
|
||||
);
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Forbidden login type."));
|
||||
};
|
||||
let username = if let Some(UserIdentifier::UserIdOrLocalpart(user_id)) = identifier {
|
||||
user_id.to_lowercase()
|
||||
} else if let Some(user_id) = user {
|
||||
warn!("Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is destined to be removed in a future Matrix release.", user_id);
|
||||
warn!(
|
||||
"Appservice \"{}\" is attempting to login with the deprecated \"user\" field at \
|
||||
\"/_matrix/client/v3/login\". conduwuit implements this deprecated behaviour, but this is \
|
||||
destined to be removed in a future Matrix release.",
|
||||
user_id
|
||||
);
|
||||
user_id.to_lowercase()
|
||||
} else {
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Bad login type."));
|
||||
};
|
||||
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(
|
||||
|e| {
|
||||
UserId::parse_with_server_name(username, services().globals.server_name()).map_err(|e| {
|
||||
warn!("Failed to parse username from appservice logging in: {}", e);
|
||||
Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid.")
|
||||
})?
|
||||
},
|
||||
)?
|
||||
}
|
||||
_ => {
|
||||
warn!("Unsupported or unknown login type: {:?}", &body.login_info);
|
||||
debug!("JSON body: {:?}", &body.json_body);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Unsupported or unknown login type.",
|
||||
));
|
||||
}
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Unsupported or unknown login type."));
|
||||
},
|
||||
};
|
||||
|
||||
// Generate new device id if the user didn't specify one
|
||||
let device_id = body
|
||||
.device_id
|
||||
.clone()
|
||||
.unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
let device_id = body.device_id.clone().unwrap_or_else(|| utils::random_string(DEVICE_ID_LENGTH).into());
|
||||
|
||||
// Generate a new token for the device
|
||||
let token = utils::random_string(TOKEN_LENGTH);
|
||||
|
||||
// Determine if device_id was provided and exists in the db for this user
|
||||
let device_exists = body.device_id.as_ref().map_or(false, |device_id| {
|
||||
services()
|
||||
.users
|
||||
.all_device_ids(&user_id)
|
||||
.any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||
services().users.all_device_ids(&user_id).any(|x| x.as_ref().map_or(false, |v| v == device_id))
|
||||
});
|
||||
|
||||
if device_exists {
|
||||
services().users.set_token(&user_id, &device_id, &token)?;
|
||||
} else {
|
||||
services().users.create_device(
|
||||
&user_id,
|
||||
&device_id,
|
||||
&token,
|
||||
body.initial_device_display_name.clone(),
|
||||
)?;
|
||||
services().users.create_device(&user_id, &device_id, &token, body.initial_device_display_name.clone())?;
|
||||
}
|
||||
|
||||
// send client well-known if specified so the client knows to reconfigure itself
|
||||
let client_discovery_info = DiscoveryInfo::new(HomeserverInfo::new(
|
||||
services()
|
||||
.globals
|
||||
.well_known_client()
|
||||
.to_owned()
|
||||
.unwrap_or_else(|| "".to_owned()),
|
||||
services().globals.well_known_client().to_owned().unwrap_or_else(|| "".to_owned()),
|
||||
));
|
||||
|
||||
info!("{} logged in", user_id);
|
||||
|
||||
// home_server is deprecated but apparently must still be sent despite it being deprecated over 6 years ago.
|
||||
// initially i thought this macro was unnecessary, but ruma uses this same macro for the same reason so...
|
||||
// home_server is deprecated but apparently must still be sent despite it being
|
||||
// deprecated over 6 years ago. initially i thought this macro was unnecessary,
|
||||
// but ruma uses this same macro for the same reason so...
|
||||
#[allow(deprecated)]
|
||||
Ok(login::v3::Response {
|
||||
user_id,
|
||||
|
@ -248,7 +223,8 @@ pub async fn login_route(body: Ruma<login::v3::Request>) -> Result<login::v3::Re
|
|||
/// Log out the current device.
|
||||
///
|
||||
/// - Invalidates access token
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets to-device events
|
||||
/// - Triggers device list updates
|
||||
pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3::Response> {
|
||||
|
@ -268,15 +244,15 @@ pub async fn logout_route(body: Ruma<logout::v3::Request>) -> Result<logout::v3:
|
|||
/// Log out all devices of this user.
|
||||
///
|
||||
/// - Invalidates all access tokens
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip, last seen ts)
|
||||
/// - Deletes all device metadata (device id, device display name, last seen ip,
|
||||
/// last seen ts)
|
||||
/// - Forgets all to-device events
|
||||
/// - Triggers device list updates
|
||||
///
|
||||
/// Note: This is equivalent to calling [`GET /_matrix/client/r0/logout`](fn.logout_route.html)
|
||||
/// from each device of this user.
|
||||
pub async fn logout_all_route(
|
||||
body: Ruma<logout_all::v3::Request>,
|
||||
) -> Result<logout_all::v3::Response> {
|
||||
/// Note: This is equivalent to calling [`GET
|
||||
/// /_matrix/client/r0/logout`](fn.logout_route.html) from each device of this
|
||||
/// user.
|
||||
pub async fn logout_all_route(body: Ruma<logout_all::v3::Request>) -> Result<logout_all::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
for device_id in services().users.all_device_ids(sender_user).flatten() {
|
||||
|
|
|
@ -1,34 +1,19 @@
use crate::{services, Result, Ruma};
use ruma::api::client::space::get_hierarchy;

use crate::{services, Result, Ruma};

/// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy``
///
/// Paginates over the space tree in a depth-first manner to locate child rooms of a given space.
pub async fn get_hierarchy_route(
body: Ruma<get_hierarchy::v1::Request>,
) -> Result<get_hierarchy::v1::Response> {
/// Paginates over the space tree in a depth-first manner to locate child rooms
/// of a given space.
pub async fn get_hierarchy_route(body: Ruma<get_hierarchy::v1::Request>) -> Result<get_hierarchy::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");

let skip = body
.from
.as_ref()
.and_then(|s| s.parse::<usize>().ok())
.unwrap_or(0);
let skip = body.from.as_ref().and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);

let limit = body.limit.map_or(10, u64::from).min(100) as usize;

let max_depth = body.max_depth.map_or(3, u64::from).min(10) as usize + 1; // +1 to skip the space room itself

services()
.rooms
.spaces
.get_hierarchy(
sender_user,
&body.room_id,
limit,
skip,
max_depth,
body.suggested_only,
)
.await
services().rooms.spaces.get_hierarchy(sender_user, &body.room_id, limit, skip, max_depth, body.suggested_only).await
}

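The hierarchy route above clamps the client-supplied pagination parameters before walking the space tree. A small self-contained sketch of just that clamping, with an illustrative helper name and the defaults taken from the hunk:

fn hierarchy_params(from: Option<&str>, limit: Option<u64>, max_depth: Option<u64>) -> (usize, usize, usize) {
	// Offset into the paginated traversal; a missing or unparsable token means "start at 0".
	let skip = from.and_then(|s| s.parse::<usize>().ok()).unwrap_or(0);
	// At most 100 rooms per page, defaulting to 10.
	let limit = limit.unwrap_or(10).min(100) as usize;
	// Depth is capped at 10; +1 so the space room itself is skipped when counting depth.
	let max_depth = max_depth.unwrap_or(3).min(10) as usize + 1;
	(skip, limit, max_depth)
}
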
@ -1,25 +1,25 @@
|
|||
use std::sync::Arc;
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||
use ruma::{
|
||||
api::client::{
|
||||
error::ErrorKind,
|
||||
state::{get_state_events, get_state_events_for_key, send_state_event},
|
||||
},
|
||||
events::{
|
||||
room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType,
|
||||
},
|
||||
events::{room::canonical_alias::RoomCanonicalAliasEventContent, AnyStateEventContent, StateEventType},
|
||||
serde::Raw,
|
||||
EventId, RoomId, UserId,
|
||||
};
|
||||
use tracing::{error, log::warn};
|
||||
|
||||
use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma, RumaResponse};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}/{stateKey}`
|
||||
///
|
||||
/// Sends a state event into the room.
|
||||
///
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_key_route(
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
|
@ -36,7 +36,9 @@ pub async fn send_state_event_for_key_route(
|
|||
.await?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id })
|
||||
Ok(send_state_event::v3::Response {
|
||||
event_id,
|
||||
})
|
||||
}
|
||||
|
||||
/// # `PUT /_matrix/client/r0/rooms/{roomId}/state/{eventType}`
|
||||
|
@ -44,7 +46,8 @@ pub async fn send_state_event_for_key_route(
|
|||
/// Sends a state event into the room.
|
||||
///
|
||||
/// - The only requirement for the content is that it has to be valid json
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is allowed
|
||||
/// - Tries to send the event into the room, auth rules will determine if it is
|
||||
/// allowed
|
||||
/// - If event is new canonical_alias: Rejects if alias is incorrect
|
||||
pub async fn send_state_event_for_empty_key_route(
|
||||
body: Ruma<send_state_event::v3::Request>,
|
||||
|
@ -53,10 +56,7 @@ pub async fn send_state_event_for_empty_key_route(
|
|||
|
||||
// Forbid m.room.encryption if encryption is disabled
|
||||
if body.event_type == StateEventType::RoomEncryption && !services().globals.allow_encryption() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Encryption has been disabled",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Encryption has been disabled"));
|
||||
}
|
||||
|
||||
let event_id = send_state_event_for_key_helper(
|
||||
|
@ -69,24 +69,24 @@ pub async fn send_state_event_for_empty_key_route(
|
|||
.await?;
|
||||
|
||||
let event_id = (*event_id).to_owned();
|
||||
Ok(send_state_event::v3::Response { event_id }.into())
|
||||
Ok(send_state_event::v3::Response {
|
||||
event_id,
|
||||
}
|
||||
.into())
|
||||
}
|
||||
|
||||
/// # `GET /_matrix/client/r0/rooms/{roomid}/state`
|
||||
///
|
||||
/// Get all state events for a room.
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_route(
|
||||
body: Ruma<get_state_events::v3::Request>,
|
||||
) -> Result<get_state_events::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
|
@ -108,42 +108,31 @@ pub async fn get_state_events_route(
|
|||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}/{stateKey}`
|
||||
///
|
||||
/// Get single state event of a room with the specified state key.
|
||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
||||
/// or just the state event's content (default behaviour)
|
||||
/// The optional query parameter `?format=event|content` allows returning the
|
||||
/// full room state event or just the state event's content (default behaviour)
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_for_key_route(
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
) -> Result<get_state_events_for_key::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, &body.state_key)?
|
||||
.ok_or_else(|| {
|
||||
warn!(
|
||||
"State event {:?} not found in room {:?}",
|
||||
&body.event_type, &body.room_id
|
||||
);
|
||||
let event =
|
||||
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, &body.state_key)?.ok_or_else(
|
||||
|| {
|
||||
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
})?;
|
||||
if body
|
||||
.format
|
||||
.as_ref()
|
||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
||||
{
|
||||
},
|
||||
)?;
|
||||
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
|
@ -165,43 +154,30 @@ pub async fn get_state_events_for_key_route(
|
|||
/// # `GET /_matrix/client/v3/rooms/{roomid}/state/{eventType}`
|
||||
///
|
||||
/// Get single state event of a room.
|
||||
/// The optional query parameter `?format=event|content` allows returning the full room state event
|
||||
/// or just the state event's content (default behaviour)
|
||||
/// The optional query parameter `?format=event|content` allows returning the
|
||||
/// full room state event or just the state event's content (default behaviour)
|
||||
///
|
||||
/// - If not joined: Only works if current room history visibility is world readable
|
||||
/// - If not joined: Only works if current room history visibility is world
|
||||
/// readable
|
||||
pub async fn get_state_events_for_empty_key_route(
|
||||
body: Ruma<get_state_events_for_key::v3::Request>,
|
||||
) -> Result<RumaResponse<get_state_events_for_key::v3::Response>> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
if !services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.user_can_see_state_events(sender_user, &body.room_id)?
|
||||
{
|
||||
if !services().rooms.state_accessor.user_can_see_state_events(sender_user, &body.room_id)? {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You don't have permission to view the room state.",
|
||||
));
|
||||
}
|
||||
|
||||
let event = services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&body.room_id, &body.event_type, "")?
|
||||
.ok_or_else(|| {
|
||||
warn!(
|
||||
"State event {:?} not found in room {:?}",
|
||||
&body.event_type, &body.room_id
|
||||
);
|
||||
let event =
|
||||
services().rooms.state_accessor.room_state_get(&body.room_id, &body.event_type, "")?.ok_or_else(|| {
|
||||
warn!("State event {:?} not found in room {:?}", &body.event_type, &body.room_id);
|
||||
Error::BadRequest(ErrorKind::NotFound, "State event not found.")
|
||||
})?;
|
||||
|
||||
if body
|
||||
.format
|
||||
.as_ref()
|
||||
.is_some_and(|f| f.to_lowercase().eq("event"))
|
||||
{
|
||||
if body.format.as_ref().is_some_and(|f| f.to_lowercase().eq("event")) {
|
||||
Ok(get_state_events_for_key::v3::Response {
|
||||
content: None,
|
||||
event: serde_json::from_str(event.to_state_event().json().get()).map_err(|e| {
|
||||
|
@ -223,19 +199,13 @@ pub async fn get_state_events_for_empty_key_route(
|
|||
}
|
||||
|
||||
async fn send_state_event_for_key_helper(
|
||||
sender: &UserId,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
json: &Raw<AnyStateEventContent>,
|
||||
state_key: String,
|
||||
sender: &UserId, room_id: &RoomId, event_type: &StateEventType, json: &Raw<AnyStateEventContent>, state_key: String,
|
||||
) -> Result<Arc<EventId>> {
|
||||
let sender_user = sender;
|
||||
|
||||
// TODO: Review this check, error if event is unparsable, use event type, allow alias if it
|
||||
// previously existed
|
||||
if let Ok(canonical_alias) =
|
||||
serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get())
|
||||
{
|
||||
// TODO: Review this check, error if event is unparsable, use event type, allow
|
||||
// alias if it previously existed
|
||||
if let Ok(canonical_alias) = serde_json::from_str::<RoomCanonicalAliasEventContent>(json.json().get()) {
|
||||
let mut aliases = canonical_alias.alt_aliases.clone();
|
||||
|
||||
if let Some(alias) = canonical_alias.alias {
|
||||
|
@ -253,22 +223,14 @@ async fn send_state_event_for_key_helper(
|
|||
{
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"You are only allowed to send canonical_alias \
|
||||
events when it's aliases already exists",
|
||||
"You are only allowed to send canonical_alias events when it's aliases already exists",
|
||||
));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let mutex_state = Arc::clone(
|
||||
services()
|
||||
.globals
|
||||
.roomid_mutex_state
|
||||
.write()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
.or_default(),
|
||||
);
|
||||
let mutex_state =
|
||||
Arc::clone(services().globals.roomid_mutex_state.write().unwrap().entry(room_id.to_owned()).or_default());
|
||||
let state_lock = mutex_state.lock().await;
|
||||
|
||||
let event_id = services()
|
||||
|
|
File diff suppressed because it is too large
|
@ -1,4 +1,5 @@
|
|||
use crate::{services, Error, Result, Ruma};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use ruma::{
|
||||
api::client::tag::{create_tag, delete_tag, get_tags},
|
||||
events::{
|
||||
|
@ -6,29 +7,21 @@ use ruma::{
|
|||
RoomAccountDataEventType,
|
||||
},
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/user/{userId}/rooms/{roomId}/tags/{tag}`
|
||||
///
|
||||
/// Adds a tag to the room.
|
||||
///
|
||||
/// - Inserts the tag into the tag event of the room account data.
|
||||
pub async fn update_tag_route(
|
||||
body: Ruma<create_tag::v3::Request>,
|
||||
) -> Result<create_tag::v3::Response> {
|
||||
pub async fn update_tag_route(body: Ruma<create_tag::v3::Request>) -> Result<create_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
|
@ -37,10 +30,7 @@ pub async fn update_tag_route(
|
|||
})
|
||||
})?;
|
||||
|
||||
tags_event
|
||||
.content
|
||||
.tags
|
||||
.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||
tags_event.content.tags.insert(body.tag.clone().into(), body.tag_info.clone());
|
||||
|
||||
services().account_data.update(
|
||||
Some(&body.room_id),
|
||||
|
@ -57,22 +47,13 @@ pub async fn update_tag_route(
|
|||
/// Deletes a tag from the room.
|
||||
///
|
||||
/// - Removes the tag from the tag event of the room account data.
|
||||
pub async fn delete_tag_route(
|
||||
body: Ruma<delete_tag::v3::Request>,
|
||||
) -> Result<delete_tag::v3::Response> {
|
||||
pub async fn delete_tag_route(body: Ruma<delete_tag::v3::Request>) -> Result<delete_tag::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let mut tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
|
@ -101,17 +82,10 @@ pub async fn delete_tag_route(
|
|||
pub async fn get_tags_route(body: Ruma<get_tags::v3::Request>) -> Result<get_tags::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
|
||||
let event = services().account_data.get(
|
||||
Some(&body.room_id),
|
||||
sender_user,
|
||||
RoomAccountDataEventType::Tag,
|
||||
)?;
|
||||
let event = services().account_data.get(Some(&body.room_id), sender_user, RoomAccountDataEventType::Tag)?;
|
||||
|
||||
let tags_event = event
|
||||
.map(|e| {
|
||||
serde_json::from_str(e.get())
|
||||
.map_err(|_| Error::bad_database("Invalid account data event in db."))
|
||||
})
|
||||
.map(|e| serde_json::from_str(e.get()).map_err(|_| Error::bad_database("Invalid account data event in db.")))
|
||||
.unwrap_or_else(|| {
|
||||
Ok(TagEvent {
|
||||
content: TagEventContent {
|
||||
|
|
|
@ -1,14 +1,13 @@
use crate::{Result, Ruma};
use std::collections::BTreeMap;

use ruma::api::client::thirdparty::get_protocols;

use std::collections::BTreeMap;
use crate::{Result, Ruma};

/// # `GET /_matrix/client/r0/thirdparty/protocols`
///
/// TODO: Fetches all metadata about protocols supported by the homeserver.
pub async fn get_protocols_route(
_body: Ruma<get_protocols::v3::Request>,
) -> Result<get_protocols::v3::Response> {
pub async fn get_protocols_route(_body: Ruma<get_protocols::v3::Request>) -> Result<get_protocols::v3::Response> {
// TODO
Ok(get_protocols::v3::Response {
protocols: BTreeMap::new(),

@ -3,21 +3,14 @@ use ruma::api::client::{error::ErrorKind, threads::get_threads};
use crate::{services, Error, Result, Ruma};

/// # `GET /_matrix/client/r0/rooms/{roomId}/threads`
pub async fn get_threads_route(
body: Ruma<get_threads::v1::Request>,
) -> Result<get_threads::v1::Response> {
pub async fn get_threads_route(body: Ruma<get_threads::v1::Request>) -> Result<get_threads::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");

// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|l| l.try_into().ok())
.unwrap_or(10)
.min(100);
let limit = body.limit.and_then(|l| l.try_into().ok()).unwrap_or(10).min(100);

let from = if let Some(from) = &body.from {
from.parse()
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
from.parse().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))?
} else {
u64::MAX
};

@ -40,10 +33,7 @@ pub async fn get_threads_route(
let next_batch = threads.last().map(|(count, _)| count.to_string());

Ok(get_threads::v1::Response {
chunk: threads
.into_iter()
.map(|(_, pdu)| pdu.to_room_event())
.collect(),
chunk: threads.into_iter().map(|(_, pdu)| pdu.to_room_event()).collect(),
next_batch,
})
}

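The threads hunk above derives the `next_batch` token for the following page from the counter of the last event it returns. A brief sketch of that idea over a hypothetical page of (count, event) pairs:

fn next_batch_token<T>(page: &[(u64, T)]) -> Option<String> {
	// An empty page means there are no further pages, so no token is returned.
	page.last().map(|(count, _)| count.to_string())
}
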
@ -1,6 +1,5 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
use ruma::{
|
||||
api::{
|
||||
client::{error::ErrorKind, to_device::send_event_to_device},
|
||||
|
@ -9,6 +8,8 @@ use ruma::{
|
|||
to_device::DeviceIdOrAllDevices,
|
||||
};
|
||||
|
||||
use crate::{services, Error, Result, Ruma};
|
||||
|
||||
/// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}`
|
||||
///
|
||||
/// Send a to-device event to a set of client devices.
|
||||
|
@ -19,11 +20,7 @@ pub async fn send_event_to_device_route(
|
|||
let sender_device = body.sender_device.as_deref();
|
||||
|
||||
// Check if this is a new transaction id
|
||||
if services()
|
||||
.transaction_ids
|
||||
.existing_txnid(sender_user, sender_device, &body.txn_id)?
|
||||
.is_some()
|
||||
{
|
||||
if services().transaction_ids.existing_txnid(sender_user, sender_device, &body.txn_id)?.is_some() {
|
||||
return Ok(send_event_to_device::v3::Response {});
|
||||
}
|
||||
|
||||
|
@ -38,14 +35,12 @@ pub async fn send_event_to_device_route(
|
|||
|
||||
services().sending.send_reliable_edu(
|
||||
target_user_id.server_name(),
|
||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(
|
||||
DirectDeviceContent {
|
||||
serde_json::to_vec(&federation::transactions::edu::Edu::DirectToDevice(DirectDeviceContent {
|
||||
sender: sender_user.clone(),
|
||||
ev_type: body.event_type.clone(),
|
||||
message_id: count.to_string().into(),
|
||||
messages,
|
||||
},
|
||||
))
|
||||
}))
|
||||
.expect("DirectToDevice EDU can be serialized"),
|
||||
count,
|
||||
)?;
|
||||
|
@ -60,11 +55,11 @@ pub async fn send_event_to_device_route(
|
|||
target_user_id,
|
||||
target_device_id,
|
||||
&body.event_type.to_string(),
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
event
|
||||
.deserialize_as()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||
)?;
|
||||
}
|
||||
},
|
||||
|
||||
DeviceIdOrAllDevices::AllDevices => {
|
||||
for target_device_id in services().users.all_device_ids(target_user_id) {
|
||||
|
@ -73,20 +68,18 @@ pub async fn send_event_to_device_route(
|
|||
target_user_id,
|
||||
&target_device_id?,
|
||||
&body.event_type.to_string(),
|
||||
event.deserialize_as().map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid")
|
||||
})?,
|
||||
event
|
||||
.deserialize_as()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event is invalid"))?,
|
||||
)?;
|
||||
}
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Save transaction id with empty data
|
||||
services()
|
||||
.transaction_ids
|
||||
.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||
services().transaction_ids.add_txnid(sender_user, sender_device, &body.txn_id, &[])?;
|
||||
|
||||
Ok(send_event_to_device::v3::Response {})
|
||||
}
|
||||
|
|
|
@ -1,6 +1,7 @@
use crate::{services, utils, Error, Result, Ruma};
use ruma::api::client::{error::ErrorKind, typing::create_typing_event};

use crate::{services, utils, Error, Result, Ruma};

/// # `PUT /_matrix/client/r0/rooms/{roomId}/typing/{userId}`
///
/// Sets the typing state of the sender user.

@ -11,15 +12,8 @@ pub async fn create_typing_event_route(

let sender_user = body.sender_user.as_ref().expect("user is authenticated");

if !services()
.rooms
.state_cache
.is_joined(sender_user, &body.room_id)?
{
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You are not in this room.",
));
if !services().rooms.state_cache.is_joined(sender_user, &body.room_id)? {
return Err(Error::BadRequest(ErrorKind::Forbidden, "You are not in this room."));
}

if let Typing::Yes(duration) = body.state {

@ -29,11 +23,7 @@ pub async fn create_typing_event_route(
duration.as_millis() as u64 + utils::millis_since_unix_epoch(),
)?;
} else {
services()
.rooms
.edus
.typing
.typing_remove(sender_user, &body.room_id)?;
services().rooms.edus.typing.typing_remove(sender_user, &body.room_id)?;
}

Ok(create_typing_event::v3::Response {})

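The typing hunk above turns the client's requested typing duration into an absolute expiry in milliseconds since the unix epoch. A standalone sketch of that arithmetic using only the standard library (the helper name is illustrative):

use std::time::{Duration, SystemTime, UNIX_EPOCH};

fn typing_expiry(duration: Duration) -> u64 {
	// Milliseconds since the unix epoch right now.
	let now_ms = SystemTime::now()
		.duration_since(UNIX_EPOCH)
		.expect("system time is after the unix epoch")
		.as_millis() as u64;
	// The typing notification expires this many milliseconds from now.
	duration.as_millis() as u64 + now_ms
}
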
@ -7,14 +7,16 @@ use crate::{services, Error, Result, Ruma};

/// # `GET /_matrix/client/versions`
///
/// Get the versions of the specification and unstable features supported by this server.
/// Get the versions of the specification and unstable features supported by
/// this server.
///
/// - Versions take the form MAJOR.MINOR.PATCH
/// - Only the latest PATCH release will be reported for each MAJOR.MINOR value
/// - Unstable features are namespaced and may include version information in their name
/// - Unstable features are namespaced and may include version information in
/// their name
///
/// Note: Unstable features are used while developing new features. Clients should avoid using
/// unstable features in their stable releases
/// Note: Unstable features are used while developing new features. Clients
/// should avoid using unstable features in their stable releases
pub async fn get_supported_versions_route(
_body: Ruma<get_supported_versions::Request>,
) -> Result<get_supported_versions::Response> {

@ -60,8 +62,8 @@ pub async fn well_known_client_route() -> Result<impl IntoResponse> {

/// # `GET /client/server.json`
///
/// Endpoint provided by sliding sync proxy used by some clients such as Element Web
/// as a non-standard health check.
/// Endpoint provided by sliding sync proxy used by some clients such as Element
/// Web as a non-standard health check.
pub async fn syncv3_client_server_json() -> Result<impl IntoResponse> {
let server_url = match services().globals.well_known_client() {
Some(url) => url.clone(),

@ -1,4 +1,3 @@
|
|||
use crate::{services, Result, Ruma};
|
||||
use ruma::{
|
||||
api::client::user_directory::search_users,
|
||||
events::{
|
||||
|
@ -7,15 +6,16 @@ use ruma::{
|
|||
},
|
||||
};
|
||||
|
||||
use crate::{services, Result, Ruma};
|
||||
|
||||
/// # `POST /_matrix/client/r0/user_directory/search`
|
||||
///
|
||||
/// Searches all known users for a match.
|
||||
///
|
||||
/// - Hides any local users that aren't in any public rooms (i.e. those that have the join rule set to public)
|
||||
/// - Hides any local users that aren't in any public rooms (i.e. those that
|
||||
/// have the join rule set to public)
|
||||
/// and don't share a room with the sender
|
||||
pub async fn search_users_route(
|
||||
body: Ruma<search_users::v3::Request>,
|
||||
) -> Result<search_users::v3::Response> {
|
||||
pub async fn search_users_route(body: Ruma<search_users::v3::Request>) -> Result<search_users::v3::Response> {
|
||||
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
|
||||
let limit = u64::from(body.limit) as usize;
|
||||
|
||||
|
@ -29,56 +29,37 @@ pub async fn search_users_route(
|
|||
avatar_url: services().users.avatar_url(&user_id).ok()?,
|
||||
};
|
||||
|
||||
let user_id_matches = user
|
||||
.user_id
|
||||
.to_string()
|
||||
.to_lowercase()
|
||||
.contains(&body.search_term.to_lowercase());
|
||||
let user_id_matches = user.user_id.to_string().to_lowercase().contains(&body.search_term.to_lowercase());
|
||||
|
||||
let user_displayname_matches = user
|
||||
.display_name
|
||||
.as_ref()
|
||||
.filter(|name| {
|
||||
name.to_lowercase()
|
||||
.contains(&body.search_term.to_lowercase())
|
||||
})
|
||||
.filter(|name| name.to_lowercase().contains(&body.search_term.to_lowercase()))
|
||||
.is_some();
|
||||
|
||||
if !user_id_matches && !user_displayname_matches {
|
||||
return None;
|
||||
}
|
||||
|
||||
let user_is_in_public_rooms = services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(&user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
.any(|room| {
|
||||
services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room, &StateEventType::RoomJoinRules, "")
|
||||
.map_or(false, |event| {
|
||||
let user_is_in_public_rooms =
|
||||
services().rooms.state_cache.rooms_joined(&user_id).filter_map(std::result::Result::ok).any(|room| {
|
||||
services().rooms.state_accessor.room_state_get(&room, &StateEventType::RoomJoinRules, "").map_or(
|
||||
false,
|
||||
|event| {
|
||||
event.map_or(false, |event| {
|
||||
serde_json::from_str(event.content.get())
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| {
|
||||
r.join_rule == JoinRule::Public
|
||||
})
|
||||
})
|
||||
.map_or(false, |r: RoomJoinRulesEventContent| r.join_rule == JoinRule::Public)
|
||||
})
|
||||
},
|
||||
)
|
||||
});
|
||||
|
||||
if user_is_in_public_rooms {
|
||||
return Some(user);
|
||||
}
|
||||
|
||||
let user_is_in_shared_rooms = services()
|
||||
.rooms
|
||||
.user
|
||||
.get_shared_rooms(vec![sender_user.clone(), user_id])
|
||||
.ok()?
|
||||
.next()
|
||||
.is_some();
|
||||
let user_is_in_shared_rooms =
|
||||
services().rooms.user.get_shared_rooms(vec![sender_user.clone(), user_id]).ok()?.next().is_some();
|
||||
|
||||
if user_is_in_shared_rooms {
|
||||
return Some(user);
|
||||
|
@ -90,5 +71,8 @@ pub async fn search_users_route(
|
|||
let results = users.by_ref().take(limit).collect();
|
||||
let limited = users.next().is_some();
|
||||
|
||||
Ok(search_users::v3::Response { results, limited })
|
||||
Ok(search_users::v3::Response {
|
||||
results,
|
||||
limited,
|
||||
})
|
||||
}
|
||||
|
|
|
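The user_directory hunk above decides whether a room counts as public by deserializing its m.room.join_rules state content and comparing the join rule. A minimal sketch of that check, assuming the usual ruma import path for these types; any parse failure counts as not public:

use ruma::events::room::join_rules::{JoinRule, RoomJoinRulesEventContent};

fn room_is_public(join_rules_json: &str) -> bool {
	// The raw JSON comes from the stored m.room.join_rules state event content.
	serde_json::from_str::<RoomJoinRulesEventContent>(join_rules_json)
		.map_or(false, |content| content.join_rule == JoinRule::Public)
}
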
@ -1,9 +1,11 @@
use crate::{services, Result, Ruma};
use std::time::{Duration, SystemTime};

use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use ruma::{api::client::voip::get_turn_server_info, SecondsSinceUnixEpoch};
use sha1::Sha1;
use std::time::{Duration, SystemTime};

use crate::{services, Result, Ruma};

type HmacSha1 = Hmac<Sha1>;

@ -25,8 +27,7 @@ pub async fn turn_server_route(

let username: String = format!("{}:{}", expiry.get(), sender_user);

let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes())
.expect("HMAC can take key of any size");
let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
mac.update(username.as_bytes());

let password: String = general_purpose::STANDARD.encode(mac.finalize().into_bytes());

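The voip hunk above derives ephemeral TURN credentials from a shared secret: the username embeds an expiry timestamp and the password is the base64-encoded HMAC-SHA1 of that username. A self-contained sketch of the same scheme using the hmac, sha1 and base64 crates (the function name is illustrative):

use base64::{engine::general_purpose, Engine as _};
use hmac::{Hmac, Mac};
use sha1::Sha1;

type HmacSha1 = Hmac<Sha1>;

fn turn_credentials(turn_secret: &str, user: &str, expiry_unix_secs: u64) -> (String, String) {
	// The TURN server recovers the expiry from the username and rejects stale credentials.
	let username = format!("{expiry_unix_secs}:{user}");
	let mut mac = HmacSha1::new_from_slice(turn_secret.as_bytes()).expect("HMAC can take key of any size");
	mac.update(username.as_bytes());
	// The password is the MAC over the username, base64-encoded.
	let password = general_purpose::STANDARD.encode(mac.finalize().into_bytes());
	(username, password)
}
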
@ -43,18 +43,16 @@ where
|
|||
let (mut parts, mut body) = match req.with_limited_body() {
|
||||
Ok(limited_req) => {
|
||||
let (parts, body) = limited_req.into_parts();
|
||||
let body = to_bytes(body)
|
||||
.await
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
let body =
|
||||
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
}
|
||||
},
|
||||
Err(original_req) => {
|
||||
let (parts, body) = original_req.into_parts();
|
||||
let body = to_bytes(body)
|
||||
.await
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
let body =
|
||||
to_bytes(body).await.map_err(|_| Error::BadRequest(ErrorKind::MissingToken, "Missing token."))?;
|
||||
(parts, body)
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let metadata = T::METADATA;
|
||||
|
@ -66,11 +64,8 @@ where
|
|||
Ok(params) => params,
|
||||
Err(e) => {
|
||||
error!(%query, "Failed to deserialize query parameters: {}", e);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Unknown,
|
||||
"Failed to read query parameters",
|
||||
));
|
||||
}
|
||||
return Err(Error::BadRequest(ErrorKind::Unknown, "Failed to read query parameters"));
|
||||
},
|
||||
};
|
||||
|
||||
let token = match &auth_header {
|
||||
|
@ -81,12 +76,12 @@ where
|
|||
let mut json_body = serde_json::from_slice::<CanonicalJsonValue>(&body).ok();
|
||||
|
||||
let appservices = services().appservice.all().unwrap();
|
||||
let appservice_registration = appservices
|
||||
.iter()
|
||||
.find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
let appservice_registration =
|
||||
appservices.iter().find(|(_id, registration)| Some(registration.as_token.as_str()) == token);
|
||||
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) =
|
||||
if let Some((_id, registration)) = appservice_registration {
|
||||
let (sender_user, sender_device, sender_servername, from_appservice) = if let Some((_id, registration)) =
|
||||
appservice_registration
|
||||
{
|
||||
match metadata.authentication {
|
||||
AuthScheme::AccessToken => {
|
||||
let user_id = query_params.user_id.map_or_else(
|
||||
|
@ -101,15 +96,12 @@ where
|
|||
);
|
||||
|
||||
if !services().users.exists(&user_id).unwrap() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"User does not exist.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "User does not exist."));
|
||||
}
|
||||
|
||||
// TODO: Check if appservice is allowed to be that user
|
||||
(Some(user_id), None, None, true)
|
||||
}
|
||||
},
|
||||
AuthScheme::ServerSignatures => (None, None, None, true),
|
||||
AuthScheme::None => (None, None, None, true),
|
||||
}
|
||||
|
@ -118,92 +110,62 @@ where
|
|||
AuthScheme::AccessToken => {
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing access token.",
|
||||
))
|
||||
}
|
||||
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
};
|
||||
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
ErrorKind::UnknownToken {
|
||||
soft_logout: false,
|
||||
},
|
||||
"Unknown access token.",
|
||||
))
|
||||
},
|
||||
Some((user_id, device_id)) => {
|
||||
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||
},
|
||||
}
|
||||
Some((user_id, device_id)) => (
|
||||
Some(user_id),
|
||||
Some(OwnedDeviceId::from(device_id)),
|
||||
None,
|
||||
false,
|
||||
),
|
||||
}
|
||||
}
|
||||
},
|
||||
AuthScheme::ServerSignatures => {
|
||||
let TypedHeader(Authorization(x_matrix)) = parts
|
||||
.extract::<TypedHeader<Authorization<XMatrix>>>()
|
||||
.await
|
||||
.map_err(|e| {
|
||||
let TypedHeader(Authorization(x_matrix)) =
|
||||
parts.extract::<TypedHeader<Authorization<XMatrix>>>().await.map_err(|e| {
|
||||
warn!("Missing or invalid Authorization header: {}", e);
|
||||
|
||||
let msg = match e.reason() {
|
||||
TypedHeaderRejectionReason::Missing => {
|
||||
"Missing Authorization header."
|
||||
}
|
||||
TypedHeaderRejectionReason::Error(_) => {
|
||||
"Invalid X-Matrix signatures."
|
||||
}
|
||||
TypedHeaderRejectionReason::Missing => "Missing Authorization header.",
|
||||
TypedHeaderRejectionReason::Error(_) => "Invalid X-Matrix signatures.",
|
||||
_ => "Unknown header-related error",
|
||||
};
|
||||
|
||||
Error::BadRequest(ErrorKind::Forbidden, msg)
|
||||
})?;
|
||||
|
||||
let origin_signatures = BTreeMap::from_iter([(
|
||||
x_matrix.key.clone(),
|
||||
CanonicalJsonValue::String(x_matrix.sig),
|
||||
)]);
|
||||
let origin_signatures =
|
||||
BTreeMap::from_iter([(x_matrix.key.clone(), CanonicalJsonValue::String(x_matrix.sig))]);
|
||||
|
||||
let signatures = BTreeMap::from_iter([(
|
||||
x_matrix.origin.as_str().to_owned(),
|
||||
CanonicalJsonValue::Object(origin_signatures),
|
||||
)]);
|
||||
|
||||
let server_destination =
|
||||
services().globals.server_name().as_str().to_owned();
|
||||
let server_destination = services().globals.server_name().as_str().to_owned();
|
||||
|
||||
if let Some(destination) = x_matrix.destination.as_ref() {
|
||||
if destination != &server_destination {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Invalid authorization.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Invalid authorization."));
|
||||
}
|
||||
}
|
||||
|
||||
let mut request_map = BTreeMap::from_iter([
|
||||
(
|
||||
"method".to_owned(),
|
||||
CanonicalJsonValue::String(parts.method.to_string()),
|
||||
),
|
||||
(
|
||||
"uri".to_owned(),
|
||||
CanonicalJsonValue::String(parts.uri.to_string()),
|
||||
),
|
||||
("method".to_owned(), CanonicalJsonValue::String(parts.method.to_string())),
|
||||
("uri".to_owned(), CanonicalJsonValue::String(parts.uri.to_string())),
|
||||
(
|
||||
"origin".to_owned(),
|
||||
CanonicalJsonValue::String(x_matrix.origin.as_str().to_owned()),
|
||||
),
|
||||
(
|
||||
"destination".to_owned(),
|
||||
CanonicalJsonValue::String(server_destination),
|
||||
),
|
||||
(
|
||||
"signatures".to_owned(),
|
||||
CanonicalJsonValue::Object(signatures),
|
||||
),
|
||||
("destination".to_owned(), CanonicalJsonValue::String(server_destination)),
|
||||
("signatures".to_owned(), CanonicalJsonValue::Object(signatures)),
|
||||
]);
|
||||
|
||||
if let Some(json_body) = &json_body {
|
||||
|
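For reference, the `request_map` assembled above mirrors the canonical JSON object that an X-Matrix signature covers in the server-server spec. A sketch of its shape with placeholder values, written with `serde_json` purely for illustration:

use serde_json::json;

// Fields checked by the signature verification for an inbound federation request;
// "content" is only present when the request carries a JSON body.
fn example_signed_request() -> serde_json::Value {
    json!({
        "method": "PUT",
        "uri": "/_matrix/federation/v1/send/42",
        "origin": "other.server.example",
        "destination": "this.server.example",
        "content": {},
        "signatures": {
            "other.server.example": {
                "ed25519:abc123": "<base64-encoded signature>"
            }
        }
    })
}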
@ -213,25 +175,18 @@ where
|
|||
let keys_result = services()
|
||||
.rooms
|
||||
.event_handler
|
||||
.fetch_signing_keys_for_server(
|
||||
&x_matrix.origin,
|
||||
vec![x_matrix.key.clone()],
|
||||
)
|
||||
.fetch_signing_keys_for_server(&x_matrix.origin, vec![x_matrix.key.clone()])
|
||||
.await;
|
||||
|
||||
let keys = match keys_result {
|
||||
Ok(b) => b,
|
||||
Err(e) => {
|
||||
warn!("Failed to fetch signing keys: {}", e);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"Failed to fetch signing keys.",
|
||||
));
|
||||
}
|
||||
return Err(Error::BadRequest(ErrorKind::Forbidden, "Failed to fetch signing keys."));
|
||||
},
|
||||
};
|
||||
|
||||
let pub_key_map =
|
||||
BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
||||
let pub_key_map = BTreeMap::from_iter([(x_matrix.origin.as_str().to_owned(), keys)]);
|
||||
|
||||
match ruma::signatures::verify_json(&pub_key_map, &request_map) {
|
||||
Ok(()) => (None, None, Some(x_matrix.origin), false),
|
||||
|
@ -243,9 +198,8 @@ where
|
|||
|
||||
if parts.uri.to_string().contains('@') {
|
||||
warn!(
|
||||
"Request uri contained '@' character. Make sure your \
|
||||
reverse proxy gives Conduit the raw uri (apache: use \
|
||||
nocanon)"
|
||||
"Request uri contained '@' character. Make sure your reverse proxy gives Conduit \
|
||||
the raw uri (apache: use nocanon)"
|
||||
);
|
||||
}
|
||||
|
||||
|
@ -253,45 +207,35 @@ where
|
|||
ErrorKind::Forbidden,
|
||||
"Failed to verify X-Matrix signatures.",
|
||||
));
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
AuthScheme::None => match parts.uri.path() {
|
||||
// allow_public_room_directory_without_auth
|
||||
"/_matrix/client/v3/publicRooms" | "/_matrix/client/r0/publicRooms" => {
|
||||
if !services()
|
||||
.globals
|
||||
.config
|
||||
.allow_public_room_directory_without_auth
|
||||
{
|
||||
if !services().globals.config.allow_public_room_directory_without_auth {
|
||||
let token = match token {
|
||||
Some(token) => token,
|
||||
_ => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::MissingToken,
|
||||
"Missing access token.",
|
||||
))
|
||||
}
|
||||
_ => return Err(Error::BadRequest(ErrorKind::MissingToken, "Missing access token.")),
|
||||
};
|
||||
|
||||
match services().users.find_from_token(token).unwrap() {
|
||||
None => {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::UnknownToken { soft_logout: false },
|
||||
ErrorKind::UnknownToken {
|
||||
soft_logout: false,
|
||||
},
|
||||
"Unknown access token.",
|
||||
))
|
||||
}
|
||||
Some((user_id, device_id)) => (
|
||||
Some(user_id),
|
||||
Some(OwnedDeviceId::from(device_id)),
|
||||
None,
|
||||
false,
|
||||
),
|
||||
},
|
||||
Some((user_id, device_id)) => {
|
||||
(Some(user_id), Some(OwnedDeviceId::from(device_id)), None, false)
|
||||
},
|
||||
}
|
||||
} else {
|
||||
(None, None, None, false)
|
||||
}
|
||||
}
|
||||
},
|
||||
_ => (None, None, None, false),
|
||||
},
|
||||
}
|
||||
|
@ -302,8 +246,7 @@ where
|
|||
|
||||
if let Some(CanonicalJsonValue::Object(json_body)) = &mut json_body {
|
||||
let user_id = sender_user.clone().unwrap_or_else(|| {
|
||||
UserId::parse_with_server_name("", services().globals.server_name())
|
||||
.expect("we know this is valid")
|
||||
UserId::parse_with_server_name("", services().globals.server_name()).expect("we know this is valid")
|
||||
});
|
||||
|
||||
let uiaa_request = json_body
|
||||
|
@ -367,9 +310,7 @@ impl Credentials for XMatrix {
|
|||
"HeaderValue to decode should start with \"X-Matrix ..\", received = {value:?}",
|
||||
);
|
||||
|
||||
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..])
|
||||
.ok()?
|
||||
.trim_start();
|
||||
let parameters = str::from_utf8(&value.as_bytes()["X-Matrix ".len()..]).ok()?.trim_start();
|
||||
|
||||
let mut origin = None;
|
||||
let mut destination = None;
|
||||
|
@ -381,10 +322,7 @@ impl Credentials for XMatrix {
|
|||
|
||||
// It's not at all clear why some fields are quoted and others not in the spec,
|
||||
// let's simply accept either form for every field.
|
||||
let value = value
|
||||
.strip_prefix('"')
|
||||
.and_then(|rest| rest.strip_suffix('"'))
|
||||
.unwrap_or(value);
|
||||
let value = value.strip_prefix('"').and_then(|rest| rest.strip_suffix('"')).unwrap_or(value);
|
||||
|
||||
// FIXME: Catch multiple fields of the same name
|
||||
match name {
|
||||
|
@ -392,10 +330,7 @@ impl Credentials for XMatrix {
|
|||
"key" => key = Some(value.to_owned()),
|
||||
"sig" => sig = Some(value.to_owned()),
|
||||
"destination" => destination = Some(value.to_owned()),
|
||||
_ => debug!(
|
||||
"Unexpected field `{}` in X-Matrix Authorization header",
|
||||
name
|
||||
),
|
||||
_ => debug!("Unexpected field `{}` in X-Matrix Authorization header", name),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -407,9 +342,7 @@ impl Credentials for XMatrix {
|
|||
})
|
||||
}
|
||||
|
||||
fn encode(&self) -> http::HeaderValue {
|
||||
todo!()
|
||||
}
|
||||
fn encode(&self) -> http::HeaderValue { todo!() }
|
||||
}
|
||||
|
||||
impl<T: OutgoingResponse> IntoResponse for RumaResponse<T> {
|
||||
|
|
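The `Credentials` impl above splits the `Authorization: X-Matrix ...` value into `origin`, `destination`, `key` and `sig` fields, accepting quoted or unquoted values. A simplified, dependency-free sketch of that parsing (naive comma splitting, no duplicate-field handling):

fn parse_x_matrix(header: &str) -> Option<(String, String, String, Option<String>)> {
    let params = header.strip_prefix("X-Matrix ")?.trim_start();
    let (mut origin, mut destination, mut key, mut sig) = (None, None, None, None);
    for field in params.split(',') {
        let (name, value) = field.trim().split_once('=')?;
        // Accept either name=value or name="value".
        let value = value.strip_prefix('"').and_then(|v| v.strip_suffix('"')).unwrap_or(value);
        match name {
            "origin" => origin = Some(value.to_owned()),
            "destination" => destination = Some(value.to_owned()),
            "key" => key = Some(value.to_owned()),
            "sig" => sig = Some(value.to_owned()),
            _ => {}, // the real code debug-logs unexpected fields
        }
    }
    Some((origin?, key?, sig?, destination))
}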
|
@ -1,10 +1,9 @@
|
|||
use crate::Error;
|
||||
use ruma::{
|
||||
api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName,
|
||||
OwnedUserId,
|
||||
};
|
||||
use std::ops::Deref;
|
||||
|
||||
use ruma::{api::client::uiaa::UiaaResponse, CanonicalJsonValue, OwnedDeviceId, OwnedServerName, OwnedUserId};
|
||||
|
||||
use crate::Error;
|
||||
|
||||
#[cfg(feature = "conduit_bin")]
|
||||
mod axum;
|
||||
|
||||
|
@ -22,22 +21,16 @@ pub struct Ruma<T> {
|
|||
impl<T> Deref for Ruma<T> {
|
||||
type Target = T;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.body
|
||||
}
|
||||
fn deref(&self) -> &Self::Target { &self.body }
|
||||
}
|
||||
|
||||
#[derive(Clone)]
|
||||
pub struct RumaResponse<T>(pub T);
|
||||
|
||||
impl<T> From<T> for RumaResponse<T> {
|
||||
fn from(t: T) -> Self {
|
||||
Self(t)
|
||||
}
|
||||
fn from(t: T) -> Self { Self(t) }
|
||||
}
|
||||
|
||||
impl From<Error> for RumaResponse<UiaaResponse> {
|
||||
fn from(t: Error) -> Self {
|
||||
t.to_response()
|
||||
}
|
||||
fn from(t: Error) -> Self { t.to_response() }
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
|
@ -178,50 +178,51 @@ pub struct TlsConfig {
|
|||
pub key: String,
|
||||
#[serde(default)]
|
||||
/// Whether to listen and allow for HTTP and HTTPS connections (insecure!)
|
||||
/// Only works / does something if the `axum_dual_protocol` feature flag was built
|
||||
/// Only works / does something if the `axum_dual_protocol` feature flag was
|
||||
/// built
|
||||
pub dual_protocol: bool,
|
||||
}
|
||||
|
||||
const DEPRECATED_KEYS: &[&str] = &["cache_capacity"];
|
||||
|
||||
impl Config {
|
||||
/// Iterates over all the keys in the config file and warns if there is a deprecated key specified
|
||||
/// Iterates over all the keys in the config file and warns if there is a
|
||||
/// deprecated key specified
|
||||
pub fn warn_deprecated(&self) {
|
||||
debug!("Checking for deprecated config keys");
|
||||
let mut was_deprecated = false;
|
||||
for key in self
|
||||
.catchall
|
||||
.keys()
|
||||
.filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key))
|
||||
{
|
||||
for key in self.catchall.keys().filter(|key| DEPRECATED_KEYS.iter().any(|s| s == key)) {
|
||||
warn!("Config parameter \"{}\" is deprecated, ignoring.", key);
|
||||
was_deprecated = true;
|
||||
}
|
||||
|
||||
if was_deprecated {
|
||||
warn!("Read conduit documentation and check your configuration if any new configuration parameters should be adjusted");
|
||||
}
|
||||
}
|
||||
|
||||
/// iterates over all the catchall keys (unknown config options) and warns if there are any.
|
||||
pub fn warn_unknown_key(&self) {
|
||||
debug!("Checking for unknown config keys");
|
||||
for key in self.catchall.keys().filter(
|
||||
|key| "config".to_owned().ne(key.to_owned()), /* "config" is expected */
|
||||
) {
|
||||
warn!(
|
||||
"Config parameter \"{}\" is unknown to conduwuit, ignoring.",
|
||||
key
|
||||
"Read conduit documentation and check your configuration if any new configuration parameters should \
|
||||
be adjusted"
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the raw_config, exiting the process if both keys were detected.
|
||||
/// iterates over all the catchall keys (unknown config options) and warns
|
||||
/// if there are any.
|
||||
pub fn warn_unknown_key(&self) {
|
||||
debug!("Checking for unknown config keys");
|
||||
for key in
|
||||
self.catchall.keys().filter(|key| "config".to_owned().ne(key.to_owned()) /* "config" is expected */)
|
||||
{
|
||||
warn!("Config parameter \"{}\" is unknown to conduwuit, ignoring.", key);
|
||||
}
|
||||
}
|
||||
|
||||
/// Checks the presence of the `address` and `unix_socket_path` keys in the
|
||||
/// raw_config, exiting the process if both keys were detected.
|
||||
pub fn is_dual_listening(&self, raw_config: Figment) -> bool {
|
||||
let check_address = raw_config.find_value("address");
|
||||
let check_unix_socket = raw_config.find_value("unix_socket_path");
|
||||
|
||||
// are the check_address and check_unix_socket keys both Ok (specified) at the same time?
|
||||
// are the check_address and check_unix_socket keys both Ok (specified) at the
|
||||
// same time?
|
||||
if check_address.is_ok() && check_unix_socket.is_ok() {
|
||||
error!("TOML keys \"address\" and \"unix_socket_path\" were both defined. Please specify only one option.");
|
||||
return true;
|
||||
|
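Both `warn_deprecated` and `warn_unknown_key` above walk a `catchall` map of keys that did not match any known field. A sketch of how such a map is typically captured with serde's `flatten`; the struct and value type here are illustrative, not the real `Config`:

use std::collections::BTreeMap;

use serde::Deserialize;

#[derive(Deserialize)]
struct ExampleConfig {
    server_name: String,
    #[serde(default)]
    allow_federation: bool,
    // Every key not named above ends up here, so it can be reported as
    // deprecated or unknown after deserialization.
    #[serde(flatten)]
    catchall: BTreeMap<String, serde_json::Value>,
}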
@ -238,28 +239,13 @@ impl fmt::Display for Config {
|
|||
("Server name", self.server_name.host()),
|
||||
("Database backend", &self.database_backend),
|
||||
("Database path", &self.database_path),
|
||||
(
|
||||
"Database cache capacity (MB)",
|
||||
&self.db_cache_capacity_mb.to_string(),
|
||||
),
|
||||
(
|
||||
"Cache capacity modifier",
|
||||
&self.conduit_cache_capacity_modifier.to_string(),
|
||||
),
|
||||
("Database cache capacity (MB)", &self.db_cache_capacity_mb.to_string()),
|
||||
("Cache capacity modifier", &self.conduit_cache_capacity_modifier.to_string()),
|
||||
("PDU cache capacity", &self.pdu_cache_capacity.to_string()),
|
||||
(
|
||||
"Cleanup interval in seconds",
|
||||
&self.cleanup_second_interval.to_string(),
|
||||
),
|
||||
("Cleanup interval in seconds", &self.cleanup_second_interval.to_string()),
|
||||
("Maximum request size (bytes)", &self.max_request_size.to_string()),
|
||||
(
|
||||
"Maximum concurrent requests",
|
||||
&self.max_concurrent_requests.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow registration",
|
||||
&self.allow_registration.to_string(),
|
||||
),
|
||||
("Maximum concurrent requests", &self.max_concurrent_requests.to_string()),
|
||||
("Allow registration", &self.allow_registration.to_string()),
|
||||
(
|
||||
"Registration token",
|
||||
match self.registration_token {
|
||||
|
@ -271,10 +257,7 @@ impl fmt::Display for Config {
|
|||
"Allow guest registration (inherently false if allow registration is false)",
|
||||
&self.allow_guest_registration.to_string(),
|
||||
),
|
||||
(
|
||||
"New user display name suffix",
|
||||
&self.new_user_displayname_suffix,
|
||||
),
|
||||
("New user display name suffix", &self.new_user_displayname_suffix),
|
||||
("Allow encryption", &self.allow_encryption.to_string()),
|
||||
("Allow federation", &self.allow_federation.to_string()),
|
||||
(
|
||||
|
@ -293,10 +276,7 @@ impl fmt::Display for Config {
|
|||
"Block non-admin room invites (local and remote, admins can still send and receive invites)",
|
||||
&self.block_non_admin_invites.to_string(),
|
||||
),
|
||||
(
|
||||
"Allow device name federation",
|
||||
&self.allow_device_name_federation.to_string(),
|
||||
),
|
||||
("Allow device name federation", &self.allow_device_name_federation.to_string()),
|
||||
("Notification push path", &self.notification_push_path),
|
||||
("Allow room creation", &self.allow_room_creation.to_string()),
|
||||
(
|
||||
|
@ -356,15 +336,9 @@ impl fmt::Display for Config {
|
|||
}
|
||||
&lst.join(", ")
|
||||
}),
|
||||
(
|
||||
"zstd Response Body Compression",
|
||||
&self.zstd_compression.to_string(),
|
||||
),
|
||||
("zstd Response Body Compression", &self.zstd_compression.to_string()),
|
||||
("RocksDB database log level", &self.rocksdb_log_level),
|
||||
(
|
||||
"RocksDB database log time-to-roll",
|
||||
&self.rocksdb_log_time_to_roll.to_string(),
|
||||
),
|
||||
("RocksDB database log time-to-roll", &self.rocksdb_log_time_to_roll.to_string()),
|
||||
(
|
||||
"RocksDB database max log file size",
|
||||
&self.rocksdb_max_log_file_size.to_string(),
|
||||
|
@ -373,10 +347,7 @@ impl fmt::Display for Config {
|
|||
"RocksDB database optimize for spinning disks",
|
||||
&self.rocksdb_optimize_for_spinning_disks.to_string(),
|
||||
),
|
||||
(
|
||||
"RocksDB Parallelism Threads",
|
||||
&self.rocksdb_parallelism_threads.to_string(),
|
||||
),
|
||||
("RocksDB Parallelism Threads", &self.rocksdb_parallelism_threads.to_string()),
|
||||
("Prevent Media Downloads From", {
|
||||
let mut lst = vec![];
|
||||
for domain in &self.prevent_media_downloads_from {
|
||||
|
@ -410,14 +381,8 @@ impl fmt::Display for Config {
|
|||
"URL preview URL contains allowlist",
|
||||
&self.url_preview_url_contains_allowlist.join(", "),
|
||||
),
|
||||
(
|
||||
"URL preview maximum spider size",
|
||||
&self.url_preview_max_spider_size.to_string(),
|
||||
),
|
||||
(
|
||||
"URL preview check root domain",
|
||||
&self.url_preview_check_root_domain.to_string(),
|
||||
),
|
||||
("URL preview maximum spider size", &self.url_preview_max_spider_size.to_string()),
|
||||
("URL preview check root domain", &self.url_preview_check_root_domain.to_string()),
|
||||
];
|
||||
|
||||
let mut msg: String = "Active config values:\n\n".to_owned();
|
||||
|
@ -430,13 +395,9 @@ impl fmt::Display for Config {
|
|||
}
|
||||
}
|
||||
|
||||
fn true_fn() -> bool {
|
||||
true
|
||||
}
|
||||
fn true_fn() -> bool { true }
|
||||
|
||||
fn default_address() -> IpAddr {
|
||||
Ipv4Addr::LOCALHOST.into()
|
||||
}
|
||||
fn default_address() -> IpAddr { Ipv4Addr::LOCALHOST.into() }
|
||||
|
||||
fn default_port() -> ListeningPort {
|
||||
ListeningPort {
|
||||
|
@ -444,25 +405,15 @@ fn default_port() -> ListeningPort {
|
|||
}
|
||||
}
|
||||
|
||||
fn default_unix_socket_perms() -> u32 {
|
||||
660
|
||||
}
|
||||
fn default_unix_socket_perms() -> u32 { 660 }
|
||||
|
||||
fn default_database_backend() -> String {
|
||||
"rocksdb".to_owned()
|
||||
}
|
||||
fn default_database_backend() -> String { "rocksdb".to_owned() }
|
||||
|
||||
fn default_db_cache_capacity_mb() -> f64 {
|
||||
300.0
|
||||
}
|
||||
fn default_db_cache_capacity_mb() -> f64 { 300.0 }
|
||||
|
||||
fn default_conduit_cache_capacity_modifier() -> f64 {
|
||||
1.0
|
||||
}
|
||||
fn default_conduit_cache_capacity_modifier() -> f64 { 1.0 }
|
||||
|
||||
fn default_pdu_cache_capacity() -> u32 {
|
||||
150_000
|
||||
}
|
||||
fn default_pdu_cache_capacity() -> u32 { 150_000 }
|
||||
|
||||
fn default_cleanup_second_interval() -> u32 {
|
||||
60 // every minute
|
||||
|
@ -472,54 +423,30 @@ fn default_max_request_size() -> u32 {
|
|||
20 * 1024 * 1024 // Default to 20 MB
|
||||
}
|
||||
|
||||
fn default_max_concurrent_requests() -> u16 {
|
||||
500
|
||||
}
|
||||
fn default_max_concurrent_requests() -> u16 { 500 }
|
||||
|
||||
fn default_max_fetch_prev_events() -> u16 {
|
||||
100_u16
|
||||
}
|
||||
fn default_max_fetch_prev_events() -> u16 { 100_u16 }
|
||||
|
||||
fn default_trusted_servers() -> Vec<OwnedServerName> {
|
||||
vec![OwnedServerName::try_from("matrix.org").unwrap()]
|
||||
}
|
||||
fn default_trusted_servers() -> Vec<OwnedServerName> { vec![OwnedServerName::try_from("matrix.org").unwrap()] }
|
||||
|
||||
fn default_log() -> String {
|
||||
"warn,state_res=warn".to_owned()
|
||||
}
|
||||
fn default_log() -> String { "warn,state_res=warn".to_owned() }
|
||||
|
||||
fn default_notification_push_path() -> String {
|
||||
"/_matrix/push/v1/notify".to_owned()
|
||||
}
|
||||
fn default_notification_push_path() -> String { "/_matrix/push/v1/notify".to_owned() }
|
||||
|
||||
fn default_turn_ttl() -> u64 {
|
||||
60 * 60 * 24
|
||||
}
|
||||
fn default_turn_ttl() -> u64 { 60 * 60 * 24 }
|
||||
|
||||
fn default_presence_idle_timeout_s() -> u64 {
|
||||
5 * 60
|
||||
}
|
||||
fn default_presence_idle_timeout_s() -> u64 { 5 * 60 }
|
||||
|
||||
fn default_presence_offline_timeout_s() -> u64 {
|
||||
30 * 60
|
||||
}
|
||||
fn default_presence_offline_timeout_s() -> u64 { 30 * 60 }
|
||||
|
||||
fn default_rocksdb_log_level() -> String {
|
||||
"warn".to_owned()
|
||||
}
|
||||
fn default_rocksdb_log_level() -> String { "warn".to_owned() }
|
||||
|
||||
fn default_rocksdb_log_time_to_roll() -> usize {
|
||||
0
|
||||
}
|
||||
fn default_rocksdb_log_time_to_roll() -> usize { 0 }
|
||||
|
||||
fn default_rocksdb_parallelism_threads() -> usize {
|
||||
num_cpus::get_physical() / 2
|
||||
}
|
||||
fn default_rocksdb_parallelism_threads() -> usize { num_cpus::get_physical() / 2 }
|
||||
|
||||
// I know, it's a great name
|
||||
pub(crate) fn default_default_room_version() -> RoomVersionId {
|
||||
RoomVersionId::V10
|
||||
}
|
||||
pub(crate) fn default_default_room_version() -> RoomVersionId { RoomVersionId::V10 }
|
||||
|
||||
fn default_rocksdb_max_log_file_size() -> usize {
|
||||
// 4 megabytes
|
||||
|
@ -554,6 +481,4 @@ fn default_url_preview_max_spider_size() -> usize {
|
|||
1_000_000 // 1MB
|
||||
}
|
||||
|
||||
fn default_new_user_displayname_suffix() -> String {
|
||||
"🏳️⚧️".to_owned()
|
||||
}
|
||||
fn default_new_user_displayname_suffix() -> String { "🏳️⚧️".to_owned() }
|
||||
|
|
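The run of `default_*` functions above exists because serde's `default` attribute takes a path to a function. A small sketch of how they are wired up; the field set is illustrative:

use serde::Deserialize;

fn default_database_backend() -> String { "rocksdb".to_owned() }
fn default_pdu_cache_capacity() -> u32 { 150_000 }
fn true_fn() -> bool { true }

#[derive(Deserialize)]
struct ExampleConfig {
    #[serde(default = "default_database_backend")]
    database_backend: String,
    #[serde(default = "default_pdu_cache_capacity")]
    pdu_cache_capacity: u32,
    #[serde(default = "true_fn")]
    allow_room_creation: bool,
}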
|
@ -24,9 +24,10 @@ use crate::Result;
|
|||
/// ## Include vs. Exclude
|
||||
/// If include is an empty list, it is assumed to be `["*"]`.
|
||||
///
|
||||
/// If a domain matches both the exclude and include list, the proxy will only be used if it was
|
||||
/// included because of a more specific rule than it was excluded. In the above example, the proxy
|
||||
/// would be used for `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
||||
/// If a domain matches both the exclude and include list, the proxy will only
|
||||
/// be used if it was included because of a more specific rule than it was
|
||||
/// excluded. In the above example, the proxy would be used for
|
||||
/// `ordinary.onion`, `matrix.myspecial.onion`, but not `hello.myspecial.onion`.
|
||||
#[derive(Clone, Default, Debug, Deserialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum ProxyConfig {
|
||||
|
@ -42,9 +43,12 @@ impl ProxyConfig {
|
|||
pub fn to_proxy(&self) -> Result<Option<Proxy>> {
|
||||
Ok(match self.clone() {
|
||||
ProxyConfig::None => None,
|
||||
ProxyConfig::Global { url } => Some(Proxy::all(url)?),
|
||||
ProxyConfig::Global {
|
||||
url,
|
||||
} => Some(Proxy::all(url)?),
|
||||
ProxyConfig::ByDomain(proxies) => Some(Proxy::custom(move |url| {
|
||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching proxy
|
||||
proxies.iter().find_map(|proxy| proxy.for_url(url)).cloned() // first matching
|
||||
// proxy
|
||||
})),
|
||||
})
|
||||
}
|
||||
|
@ -85,7 +89,8 @@ impl PartialProxyConfig {
|
|||
}
|
||||
}
|
||||
match (included_because, excluded_because) {
|
||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), // included for a more specific reason than excluded
|
||||
(Some(a), Some(b)) if a.more_specific_than(b) => Some(&self.url), /* included for a more specific reason */
|
||||
// than excluded
|
||||
(Some(_), None) => Some(&self.url),
|
||||
_ => None,
|
||||
}
|
||||
|
@ -107,20 +112,20 @@ impl WildCardedDomain {
|
|||
WildCardedDomain::Exact(d) => domain == d,
|
||||
}
|
||||
}
|
||||
|
||||
fn more_specific_than(&self, other: &Self) -> bool {
|
||||
match (self, other) {
|
||||
(WildCardedDomain::WildCard, WildCardedDomain::WildCard) => false,
|
||||
(_, WildCardedDomain::WildCard) => true,
|
||||
(WildCardedDomain::Exact(a), WildCardedDomain::WildCarded(_)) => other.matches(a),
|
||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => {
|
||||
a != b && a.ends_with(b)
|
||||
}
|
||||
(WildCardedDomain::WildCarded(a), WildCardedDomain::WildCarded(b)) => a != b && a.ends_with(b),
|
||||
_ => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
impl std::str::FromStr for WildCardedDomain {
|
||||
type Err = std::convert::Infallible;
|
||||
|
||||
fn from_str(s: &str) -> Result<Self, Self::Err> {
|
||||
// maybe do some domain validation?
|
||||
Ok(if s.starts_with("*.") {
|
||||
|
|
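The proxy rules above hinge on `WildCardedDomain`: a bare "*" matches everything, "*.example.com" matches by suffix, and anything else must match exactly, with the more specific side winning when include and exclude both match. A trimmed-down sketch of the matching itself:

enum WildCardedDomain {
    WildCard,           // "*"
    WildCarded(String), // suffix form of "*.example.com"
    Exact(String),      // "example.com"
}

impl WildCardedDomain {
    fn matches(&self, domain: &str) -> bool {
        match self {
            WildCardedDomain::WildCard => true,
            WildCardedDomain::WildCarded(suffix) => domain.ends_with(suffix),
            WildCardedDomain::Exact(exact) => domain == exact,
        }
    }
}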
|
@ -1,8 +1,8 @@
|
|||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
use super::Config;
|
||||
use crate::Result;
|
||||
|
||||
use std::{future::Future, pin::Pin, sync::Arc};
|
||||
|
||||
#[cfg(feature = "sqlite")]
|
||||
pub mod sqlite;
|
||||
|
||||
|
@ -18,9 +18,7 @@ pub(crate) trait KeyValueDatabaseEngine: Send + Sync {
|
|||
Self: Sized;
|
||||
fn open_tree(&self, name: &'static str) -> Result<Arc<dyn KvTree>>;
|
||||
fn flush(&self) -> Result<()>;
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
Ok(())
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { Ok(()) }
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
Ok("Current database engine does not support memory usage reporting.".to_owned())
|
||||
}
|
||||
|
@ -39,19 +37,12 @@ pub(crate) trait KvTree: Send + Sync {
|
|||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn increment(&self, key: &[u8]) -> Result<Vec<u8>>;
|
||||
fn increment_batch(&self, iter: &mut dyn Iterator<Item = Vec<u8>>) -> Result<()>;
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a>;
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>>;
|
||||
|
||||
|
|
|
@ -7,9 +7,8 @@ use std::{
|
|||
use rocksdb::LogLevel::{Debug, Error, Fatal, Info, Warn};
|
||||
use tracing::{debug, info};
|
||||
|
||||
use crate::{utils, Result};
|
||||
|
||||
use super::{super::Config, watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{utils, Result};
|
||||
|
||||
pub(crate) struct Engine {
|
||||
rocks: rocksdb::DBWithThreadMode<rocksdb::MultiThreaded>,
|
||||
|
@ -62,7 +61,9 @@ fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Optio
|
|||
db_opts.set_compaction_readahead_size(2 * 1024 * 1024); // default compaction_readahead_size is 0 which is good for SSDs
|
||||
db_opts.set_target_file_size_base(256 * 1024 * 1024); // default target_file_size is 64MB which is good for SSDs
|
||||
db_opts.set_optimize_filters_for_hits(true); // doesn't really seem useful for fast storage
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for spinning hard drives. these are not really important
|
||||
db_opts.set_keep_log_file_num(3); // keep as few LOG files as possible for
|
||||
// spinning hard drives. these are not really
|
||||
// important
|
||||
} else {
|
||||
db_opts.set_skip_stats_update_on_db_open(false);
|
||||
db_opts.set_max_bytes_for_level_base(512 * 1024 * 1024);
|
||||
|
@ -75,9 +76,7 @@ fn db_options(rocksdb_cache: &rocksdb::Cache, config: &Config) -> rocksdb::Optio
|
|||
db_opts.set_level_compaction_dynamic_level_bytes(true);
|
||||
db_opts.create_if_missing(true);
|
||||
db_opts.increase_parallelism(
|
||||
threads
|
||||
.try_into()
|
||||
.expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
threads.try_into().expect("Failed to convert \"rocksdb_parallelism_threads\" usize into i32"),
|
||||
);
|
||||
//db_opts.set_max_open_files(config.rocksdb_max_open_files);
|
||||
db_opts.set_compression_type(rocksdb::DBCompressionType::Zstd);
|
||||
|
@ -109,10 +108,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
let db_opts = db_options(&rocksdb_cache, config);
|
||||
|
||||
debug!("Listing column families in database");
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
)
|
||||
let cfs = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::list_cf(&db_opts, &config.database_path)
|
||||
.unwrap_or_default();
|
||||
|
||||
debug!("Opening column family descriptors in database");
|
||||
|
@ -120,9 +116,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
let db = rocksdb::DBWithThreadMode::<rocksdb::MultiThreaded>::open_cf_descriptors(
|
||||
&db_opts,
|
||||
&config.database_path,
|
||||
cfs.iter().map(|name| {
|
||||
rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))
|
||||
}),
|
||||
cfs.iter().map(|name| rocksdb::ColumnFamilyDescriptor::new(name, db_options(&rocksdb_cache, config))),
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(Engine {
|
||||
|
@ -137,9 +131,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
if !self.old_cfs.contains(&name.to_owned()) {
|
||||
// Create if it didn't exist
|
||||
debug!("Creating new column family in database: {}", name);
|
||||
let _ = self
|
||||
.rocks
|
||||
.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
let _ = self.rocks.create_cf(name, &db_options(&self.cache, &self.config));
|
||||
}
|
||||
|
||||
Ok(Arc::new(RocksDbEngineTree {
|
||||
|
@ -156,15 +148,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
}
|
||||
|
||||
fn memory_usage(&self) -> Result<String> {
|
||||
let stats =
|
||||
rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
let stats = rocksdb::perf::get_memory_usage_stats(Some(&[&self.rocks]), Some(&[&self.cache]))?;
|
||||
Ok(format!(
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of un-flushed mem-tables: {:.3} MB\n\
|
||||
Approximate memory usage of all the table readers: {:.3} MB\n\
|
||||
Approximate memory usage by cache: {:.3} MB\n\
|
||||
Approximate memory usage by cache pinned: {:.3} MB\n\
|
||||
",
|
||||
"Approximate memory usage of all the mem-tables: {:.3} MB\nApproximate memory usage of un-flushed \
|
||||
mem-tables: {:.3} MB\nApproximate memory usage of all the table readers: {:.3} MB\nApproximate memory \
|
||||
usage by cache: {:.3} MB\nApproximate memory usage by cache pinned: {:.3} MB\n",
|
||||
stats.mem_table_total as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_unflushed as f64 / 1024.0 / 1024.0,
|
||||
stats.mem_table_readers_total as f64 / 1024.0 / 1024.0,
|
||||
|
@ -179,15 +167,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
}
|
||||
|
||||
impl RocksDbEngineTree<'_> {
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> {
|
||||
self.db.rocks.cf_handle(self.name).unwrap()
|
||||
}
|
||||
fn cf(&self) -> Arc<rocksdb::BoundColumnFamily<'_>> { self.db.rocks.cf_handle(self.name).unwrap() }
|
||||
}
|
||||
|
||||
impl KvTree for RocksDbEngineTree<'_> {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
Ok(self.db.rocks.get_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { Ok(self.db.rocks.get_cf(&self.cf(), key)?) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let lock = self.write_lock.read().unwrap();
|
||||
|
@ -207,9 +191,7 @@ impl KvTree for RocksDbEngineTree<'_> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
Ok(self.db.rocks.delete_cf(&self.cf(), key)?)
|
||||
}
|
||||
fn remove(&self, key: &[u8]) -> Result<()> { Ok(self.db.rocks.delete_cf(&self.cf(), key)?) }
|
||||
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
|
@ -221,11 +203,7 @@ impl KvTree for RocksDbEngineTree<'_> {
|
|||
)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
|
@ -270,17 +248,11 @@ impl KvTree for RocksDbEngineTree<'_> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn scan_prefix<'a>(
|
||||
&'a self,
|
||||
prefix: Vec<u8>,
|
||||
) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = (Vec<u8>, Vec<u8>)> + 'a> {
|
||||
Box::new(
|
||||
self.db
|
||||
.rocks
|
||||
.iterator_cf(
|
||||
&self.cf(),
|
||||
rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward),
|
||||
)
|
||||
.iterator_cf(&self.cf(), rocksdb::IteratorMode::From(&prefix, rocksdb::Direction::Forward))
|
||||
.map(std::result::Result::unwrap)
|
||||
.map(|(k, v)| (Vec::from(k), Vec::from(v)))
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix)),
|
||||
|
|
|
@ -1,7 +1,3 @@
|
|||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||
use std::{
|
||||
cell::RefCell,
|
||||
future::Future,
|
||||
|
@ -9,9 +5,15 @@ use std::{
|
|||
pin::Pin,
|
||||
sync::Arc,
|
||||
};
|
||||
|
||||
use parking_lot::{Mutex, MutexGuard};
|
||||
use rusqlite::{Connection, DatabaseName::Main, OptionalExtension};
|
||||
use thread_local::ThreadLocal;
|
||||
use tracing::debug;
|
||||
|
||||
use super::{watchers::Watchers, KeyValueDatabaseEngine, KvTree};
|
||||
use crate::{database::Config, Result};
|
||||
|
||||
thread_local! {
|
||||
static READ_CONNECTION: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
static READ_CONNECTION_ITERATOR: RefCell<Option<&'static Connection>> = RefCell::new(None);
|
||||
|
@ -25,9 +27,7 @@ struct PreparedStatementIterator<'a> {
|
|||
impl Iterator for PreparedStatementIterator<'_> {
|
||||
type Item = TupleOfBytes;
|
||||
|
||||
fn next(&mut self) -> Option<Self::Item> {
|
||||
self.iterator.next()
|
||||
}
|
||||
fn next(&mut self) -> Option<Self::Item> { self.iterator.next() }
|
||||
}
|
||||
|
||||
struct NonAliasingBox<T>(*mut T);
|
||||
|
@ -61,23 +61,18 @@ impl Engine {
|
|||
Ok(conn)
|
||||
}
|
||||
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> {
|
||||
self.writer.lock()
|
||||
}
|
||||
fn write_lock(&self) -> MutexGuard<'_, Connection> { self.writer.lock() }
|
||||
|
||||
fn read_lock(&self) -> &Connection {
|
||||
self.read_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
self.read_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
fn read_lock_iterator(&self) -> &Connection {
|
||||
self.read_iterator_conn_tls
|
||||
.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
self.read_iterator_conn_tls.get_or(|| Self::prepare_conn(&self.path, self.cache_size_per_thread).unwrap())
|
||||
}
|
||||
|
||||
pub fn flush_wal(self: &Arc<Self>) -> Result<()> {
|
||||
self.write_lock()
|
||||
.pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
self.write_lock().pragma_update(Some(Main), "wal_checkpoint", "RESTART")?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -88,11 +83,11 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
|
||||
// calculates cache-size per permanent connection
|
||||
// 1. convert MB to KiB
|
||||
// 2. divide by permanent connections + permanent iter connections + write connection
|
||||
// 2. divide by permanent connections + permanent iter connections + write
|
||||
// connection
|
||||
// 3. round down to nearest integer
|
||||
let cache_size_per_thread: u32 = ((config.db_cache_capacity_mb * 1024.0)
|
||||
/ ((num_cpus::get().max(1) * 2) + 1) as f64)
|
||||
as u32;
|
||||
let cache_size_per_thread: u32 =
|
||||
((config.db_cache_capacity_mb * 1024.0) / ((num_cpus::get().max(1) * 2) + 1) as f64) as u32;
|
||||
|
||||
let writer = Mutex::new(Engine::prepare_conn(&path, cache_size_per_thread)?);
|
||||
|
||||
|
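The cache-size calculation above splits `db_cache_capacity_mb` across one write connection plus two permanent read connections per CPU. A worked example of the arithmetic, assuming the 300 MB default and 4 CPUs:

// connections = (num_cpus * 2) + 1 = 9
// per-connection cache = (300.0 * 1024.0) / 9 ≈ 34_133 KiB
fn cache_size_per_thread(db_cache_capacity_mb: f64, num_cpus: usize) -> u32 {
    ((db_cache_capacity_mb * 1024.0) / ((num_cpus.max(1) * 2) + 1) as f64) as u32
}

// cache_size_per_thread(300.0, 4) == 34_133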
@ -108,7 +103,10 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
}
|
||||
|
||||
fn open_tree(&self, name: &str) -> Result<Arc<dyn KvTree>> {
|
||||
self.write_lock().execute(&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"), [])?;
|
||||
self.write_lock().execute(
|
||||
&format!("CREATE TABLE IF NOT EXISTS {name} ( \"key\" BLOB PRIMARY KEY, \"value\" BLOB NOT NULL )"),
|
||||
[],
|
||||
)?;
|
||||
|
||||
Ok(Arc::new(SqliteTable {
|
||||
engine: Arc::clone(self),
|
||||
|
@ -122,9 +120,7 @@ impl KeyValueDatabaseEngine for Arc<Engine> {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.flush_wal()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.flush_wal() }
|
||||
}
|
||||
|
||||
pub struct SqliteTable {
|
||||
|
@ -145,27 +141,15 @@ impl SqliteTable {
|
|||
|
||||
fn insert_with_guard(&self, guard: &Connection, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
guard.execute(
|
||||
format!(
|
||||
"INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)",
|
||||
self.name
|
||||
)
|
||||
.as_str(),
|
||||
format!("INSERT OR REPLACE INTO {} (key, value) VALUES (?, ?)", self.name).as_str(),
|
||||
[key, value],
|
||||
)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub fn iter_with_guard<'a>(
|
||||
&'a self,
|
||||
guard: &'a Connection,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
pub fn iter_with_guard<'a>(&'a self, guard: &'a Connection) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let statement = Box::leak(Box::new(
|
||||
guard
|
||||
.prepare(&format!(
|
||||
"SELECT key, value FROM {} ORDER BY key ASC",
|
||||
&self.name
|
||||
))
|
||||
.unwrap(),
|
||||
guard.prepare(&format!("SELECT key, value FROM {} ORDER BY key ASC", &self.name)).unwrap(),
|
||||
));
|
||||
|
||||
let statement_ref = NonAliasingBox(statement);
|
||||
|
@ -173,10 +157,7 @@ impl SqliteTable {
|
|||
//let name = self.name.clone();
|
||||
|
||||
let iterator = Box::new(
|
||||
statement
|
||||
.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1))))
|
||||
.unwrap()
|
||||
.map(move |r| r.unwrap()),
|
||||
statement.query_map([], |row| Ok((row.get_unwrap(0), row.get_unwrap(1)))).unwrap().map(move |r| r.unwrap()),
|
||||
);
|
||||
|
||||
Box::new(PreparedStatementIterator {
|
||||
|
@ -187,9 +168,7 @@ impl SqliteTable {
|
|||
}
|
||||
|
||||
impl KvTree for SqliteTable {
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> {
|
||||
self.get_with_guard(self.engine.read_lock(), key)
|
||||
}
|
||||
fn get(&self, key: &[u8]) -> Result<Option<Vec<u8>>> { self.get_with_guard(self.engine.read_lock(), key) }
|
||||
|
||||
fn insert(&self, key: &[u8], value: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
@ -219,8 +198,7 @@ impl KvTree for SqliteTable {
|
|||
guard.execute("BEGIN", [])?;
|
||||
for key in iter {
|
||||
let old = self.get_with_guard(&guard, &key)?;
|
||||
let new = crate::utils::increment(old.as_deref())
|
||||
.expect("utils::increment always returns Some");
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
self.insert_with_guard(&guard, &key, &new)?;
|
||||
}
|
||||
guard.execute("COMMIT", [])?;
|
||||
|
@ -233,10 +211,7 @@ impl KvTree for SqliteTable {
|
|||
fn remove(&self, key: &[u8]) -> Result<()> {
|
||||
let guard = self.engine.write_lock();
|
||||
|
||||
guard.execute(
|
||||
format!("DELETE FROM {} WHERE key = ?", self.name).as_str(),
|
||||
[key],
|
||||
)?;
|
||||
guard.execute(format!("DELETE FROM {} WHERE key = ?", self.name).as_str(), [key])?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -247,11 +222,7 @@ impl KvTree for SqliteTable {
|
|||
self.iter_with_guard(guard)
|
||||
}
|
||||
|
||||
fn iter_from<'a>(
|
||||
&'a self,
|
||||
from: &[u8],
|
||||
backwards: bool,
|
||||
) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
fn iter_from<'a>(&'a self, from: &[u8], backwards: bool) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
let guard = self.engine.read_lock_iterator();
|
||||
let from = from.to_vec(); // TODO change interface?
|
||||
|
||||
|
@ -310,8 +281,7 @@ impl KvTree for SqliteTable {
|
|||
|
||||
let old = self.get_with_guard(&guard, key)?;
|
||||
|
||||
let new =
|
||||
crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
let new = crate::utils::increment(old.as_deref()).expect("utils::increment always returns Some");
|
||||
|
||||
self.insert_with_guard(&guard, key, &new)?;
|
||||
|
||||
|
@ -319,10 +289,7 @@ impl KvTree for SqliteTable {
|
|||
}
|
||||
|
||||
fn scan_prefix<'a>(&'a self, prefix: Vec<u8>) -> Box<dyn Iterator<Item = TupleOfBytes> + 'a> {
|
||||
Box::new(
|
||||
self.iter_from(&prefix, false)
|
||||
.take_while(move |(key, _)| key.starts_with(&prefix)),
|
||||
)
|
||||
Box::new(self.iter_from(&prefix, false).take_while(move |(key, _)| key.starts_with(&prefix)))
|
||||
}
|
||||
|
||||
fn watch_prefix<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
|
@ -331,9 +298,7 @@ impl KvTree for SqliteTable {
|
|||
|
||||
fn clear(&self) -> Result<()> {
|
||||
debug!("clear: running");
|
||||
self.engine
|
||||
.write_lock()
|
||||
.execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
self.engine.write_lock().execute(format!("DELETE FROM {}", self.name).as_str(), [])?;
|
||||
debug!("clear: ran");
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
@ -4,6 +4,7 @@ use std::{
|
|||
pin::Pin,
|
||||
sync::RwLock,
|
||||
};
|
||||
|
||||
use tokio::sync::watch;
|
||||
|
||||
type Watcher = RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>;
|
||||
|
@ -14,17 +15,14 @@ pub(super) struct Watchers {
|
|||
}
|
||||
|
||||
impl Watchers {
|
||||
pub(super) fn watch<'a>(
|
||||
&'a self,
|
||||
prefix: &[u8],
|
||||
) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
pub(super) fn watch<'a>(&'a self, prefix: &[u8]) -> Pin<Box<dyn Future<Output = ()> + Send + 'a>> {
|
||||
let mut rx = match self.watchers.write().unwrap().entry(prefix.to_vec()) {
|
||||
hash_map::Entry::Occupied(o) => o.get().1.clone(),
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
let (tx, rx) = tokio::sync::watch::channel(());
|
||||
v.insert((tx, rx.clone()));
|
||||
rx
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
Box::pin(async move {
|
||||
|
@ -32,6 +30,7 @@ impl Watchers {
|
|||
rx.changed().await.unwrap();
|
||||
})
|
||||
}
|
||||
|
||||
pub(super) fn wake(&self, key: &[u8]) {
|
||||
let watchers = self.watchers.read().unwrap();
|
||||
let mut triggered = Vec::new();
|
||||
|
|
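`Watchers` above keeps one tokio `watch` channel per byte prefix: `watch` registers (or reuses) a channel for the prefix and awaits it, and `wake` notifies every channel whose prefix the written key starts with. A compact, slightly simplified sketch of the same idea:

use std::{collections::HashMap, sync::RwLock};

use tokio::sync::watch;

#[derive(Default)]
struct Watchers {
    watchers: RwLock<HashMap<Vec<u8>, (watch::Sender<()>, watch::Receiver<()>)>>,
}

impl Watchers {
    // Await the next write under `prefix`.
    async fn watch(&self, prefix: &[u8]) {
        let mut rx = {
            let mut map = self.watchers.write().unwrap();
            map.entry(prefix.to_vec()).or_insert_with(|| watch::channel(())).1.clone()
        }; // lock guard dropped here, before the await
        rx.changed().await.ok();
    }

    // Notify every watcher whose prefix the written key starts with.
    fn wake(&self, key: &[u8]) {
        let map = self.watchers.read().unwrap();
        for (prefix, (tx, _)) in map.iter() {
            if key.starts_with(prefix) {
                let _ = tx.send(());
            }
        }
    }
}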
|
@ -11,27 +11,21 @@ use tracing::warn;
|
|||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
impl service::account_data::Data for KeyValueDatabase {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut roomuserdataid = prefix.clone();
|
||||
roomuserdataid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
roomuserdataid.push(0xff);
|
||||
roomuserdataid.push(0xFF);
|
||||
roomuserdataid.extend_from_slice(event_type.to_string().as_bytes());
|
||||
|
||||
let mut key = prefix;
|
||||
|
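The `update` implementation above encodes everything into flat byte keys with 0xFF as the field separator: `room_id? 0xFF user_id 0xFF event_type` for the lookup key, and `room_id? 0xFF user_id 0xFF count 0xFF event_type` for the data id. A sketch of the lookup-key construction:

// Lookup key for roomusertype_roomuserdataid; an empty room id means
// global account data, exactly like the unwrap_or_default() above.
fn roomusertype_key(room_id: Option<&str>, user_id: &str, event_type: &str) -> Vec<u8> {
    let mut key = room_id.unwrap_or_default().as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(user_id.as_bytes());
    key.push(0xFF);
    key.extend_from_slice(event_type.as_bytes());
    key
}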
@ -51,8 +45,7 @@ impl service::account_data::Data for KeyValueDatabase {
|
|||
|
||||
let prev = self.roomusertype_roomuserdataid.get(&key)?;
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
self.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||
|
||||
// Remove old entry
|
||||
if let Some(prev) = prev {
|
||||
|
@ -65,54 +58,33 @@ impl service::account_data::Data for KeyValueDatabase {
|
|||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, kind))]
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
let mut key = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
key.push(0xff);
|
||||
let mut key = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(kind.to_string().as_bytes());
|
||||
|
||||
self.roomusertype_roomuserdataid
|
||||
.get(&key)?
|
||||
.and_then(|roomuserdataid| {
|
||||
self.roomuserdataid_accountdata
|
||||
.get(&roomuserdataid)
|
||||
.transpose()
|
||||
})
|
||||
.and_then(|roomuserdataid| self.roomuserdataid_accountdata.get(&roomuserdataid).transpose())
|
||||
.transpose()?
|
||||
.map(|data| {
|
||||
serde_json::from_slice(&data)
|
||||
.map_err(|_| Error::bad_database("could not deserialize"))
|
||||
})
|
||||
.map(|data| serde_json::from_slice(&data).map_err(|_| Error::bad_database("could not deserialize")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
let mut userdata = HashMap::new();
|
||||
|
||||
let mut prefix = room_id
|
||||
.map(std::string::ToString::to_string)
|
||||
.unwrap_or_default()
|
||||
.as_bytes()
|
||||
.to_vec();
|
||||
prefix.push(0xff);
|
||||
let mut prefix = room_id.map(std::string::ToString::to_string).unwrap_or_default().as_bytes().to_vec();
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(user_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
// Skip the data that's exactly at since, because we sent that last time
|
||||
let mut first_possible = prefix.clone();
|
||||
|
@ -125,20 +97,20 @@ impl service::account_data::Data for KeyValueDatabase {
|
|||
.map(|(k, v)| {
|
||||
Ok::<_, Error>((
|
||||
RoomAccountDataEventType::from(
|
||||
utils::string_from_bytes(k.rsplit(|&b| b == 0xff).next().ok_or_else(
|
||||
|| Error::bad_database("RoomUserData ID in db is invalid."),
|
||||
)?)
|
||||
utils::string_from_bytes(
|
||||
k.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("RoomUserData ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|e| {
|
||||
warn!("RoomUserData ID in database is invalid: {}", e);
|
||||
Error::bad_database("RoomUserData ID in db is invalid.")
|
||||
})?,
|
||||
),
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v).map_err(|_| {
|
||||
Error::bad_database("Database contains invalid account data.")
|
||||
})?,
|
||||
serde_json::from_slice::<Raw<AnyEphemeralRoomEvent>>(&v)
|
||||
.map_err(|_| Error::bad_database("Database contains invalid account data."))?,
|
||||
))
|
||||
})
|
||||
{
|
||||
}) {
|
||||
let (kind, data) = r?;
|
||||
userdata.insert(kind, data);
|
||||
}
|
||||
|
|
|
@ -6,14 +6,8 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
/// Registers an appservice and returns the ID to the caller
|
||||
fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
let id = yaml.id.as_str();
|
||||
self.id_appserviceregistrations.insert(
|
||||
id.as_bytes(),
|
||||
serde_yaml::to_string(&yaml).unwrap().as_bytes(),
|
||||
)?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(id.to_owned(), yaml.clone());
|
||||
self.id_appserviceregistrations.insert(id.as_bytes(), serde_yaml::to_string(&yaml).unwrap().as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().insert(id.to_owned(), yaml.clone());
|
||||
|
||||
Ok(id.to_owned())
|
||||
}
|
||||
|
@ -24,29 +18,19 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
///
|
||||
/// * `service_name` - the name you send to register the service previously
|
||||
fn unregister_appservice(&self, service_name: &str) -> Result<()> {
|
||||
self.id_appserviceregistrations
|
||||
.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(service_name);
|
||||
self.id_appserviceregistrations.remove(service_name.as_bytes())?;
|
||||
self.cached_registrations.write().unwrap().remove(service_name);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.cached_registrations
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(id)
|
||||
.map_or_else(
|
||||
self.cached_registrations.read().unwrap().get(id).map_or_else(
|
||||
|| {
|
||||
self.id_appserviceregistrations
|
||||
.get(id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
serde_yaml::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid registration bytes in id_appserviceregistrations.",
|
||||
)
|
||||
Error::bad_database("Invalid registration bytes in id_appserviceregistrations.")
|
||||
})
|
||||
})
|
||||
.transpose()
|
||||
|
@ -56,13 +40,10 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Result<Box<dyn Iterator<Item = Result<String>> + 'a>> {
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(
|
||||
|(id, _)| {
|
||||
utils::string_from_bytes(&id).map_err(|_| {
|
||||
Error::bad_database("Invalid id bytes in id_appserviceregistrations.")
|
||||
})
|
||||
},
|
||||
)))
|
||||
Ok(Box::new(self.id_appserviceregistrations.iter().map(|(id, _)| {
|
||||
utils::string_from_bytes(&id)
|
||||
.map_err(|_| Error::bad_database("Invalid id bytes in id_appserviceregistrations."))
|
||||
})))
|
||||
}
|
||||
|
||||
fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
|
@ -71,8 +52,7 @@ impl service::appservice::Data for KeyValueDatabase {
|
|||
.map(move |id| {
|
||||
Ok((
|
||||
id.clone(),
|
||||
self.get_registration(&id)?
|
||||
.expect("iter_ids only returns appservices that exist"),
|
||||
self.get_registration(&id)?.expect("iter_ids only returns appservices that exist"),
|
||||
))
|
||||
})
|
||||
.collect()
|
||||
|
|
|
@ -23,24 +23,19 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
|
||||
fn current_count(&self) -> Result<u64> {
|
||||
self.global.get(COUNTER)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.global
|
||||
.get(LAST_CHECK_FOR_UPDATES_COUNT)?
|
||||
.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("last check for updates count has invalid bytes.")
|
||||
})
|
||||
self.global.get(LAST_CHECK_FOR_UPDATES_COUNT)?.map_or(Ok(0_u64), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("last check for updates count has invalid bytes."))
|
||||
})
|
||||
}
|
||||
|
||||
fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.global
|
||||
.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
self.global.insert(LAST_CHECK_FOR_UPDATES_COUNT, &id.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -48,11 +43,11 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let userid_bytes = user_id.as_bytes().to_vec();
|
||||
let mut userid_prefix = userid_bytes.clone();
|
||||
userid_prefix.push(0xff);
|
||||
userid_prefix.push(0xFF);
|
||||
|
||||
let mut userdeviceid_prefix = userid_prefix.clone();
|
||||
userdeviceid_prefix.extend_from_slice(device_id.as_bytes());
|
||||
userdeviceid_prefix.push(0xff);
|
||||
userdeviceid_prefix.push(0xFF);
|
||||
|
||||
let mut futures = FuturesUnordered::new();
|
||||
|
||||
|
@ -63,19 +58,11 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
futures.push(self.userroomid_joined.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_invitestate.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_leftstate.watch_prefix(&userid_prefix));
|
||||
futures.push(
|
||||
self.userroomid_notificationcount
|
||||
.watch_prefix(&userid_prefix),
|
||||
);
|
||||
futures.push(self.userroomid_notificationcount.watch_prefix(&userid_prefix));
|
||||
futures.push(self.userroomid_highlightcount.watch_prefix(&userid_prefix));
|
||||
|
||||
// Events for rooms we are in
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||
let short_roomid = services()
|
||||
.rooms
|
||||
.short
|
||||
|
@ -88,7 +75,7 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
|
||||
let roomid_bytes = room_id.as_bytes().to_vec();
|
||||
let mut roomid_prefix = roomid_bytes.clone();
|
||||
roomid_prefix.push(0xff);
|
||||
roomid_prefix.push(0xFF);
|
||||
|
||||
// PDUs
|
||||
futures.push(self.pduid_pdu.watch_prefix(&short_roomid));
|
||||
|
@ -105,19 +92,13 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
let mut roomuser_prefix = roomid_prefix.clone();
|
||||
roomuser_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&roomuser_prefix),
|
||||
);
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&roomuser_prefix));
|
||||
}
|
||||
|
||||
let mut globaluserdata_prefix = vec![0xff];
|
||||
let mut globaluserdata_prefix = vec![0xFF];
|
||||
globaluserdata_prefix.extend_from_slice(&userid_prefix);
|
||||
|
||||
futures.push(
|
||||
self.roomusertype_roomuserdataid
|
||||
.watch_prefix(&globaluserdata_prefix),
|
||||
);
|
||||
futures.push(self.roomusertype_roomuserdataid.watch_prefix(&globaluserdata_prefix));
|
||||
|
||||
// More key changes (used when user is not joined to any rooms)
|
||||
futures.push(self.keychangeid_userid.watch_prefix(&userid_prefix));
|
||||
|
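The `watch` implementation above pushes one `watch_prefix` future per relevant tree into a `FuturesUnordered` and then waits for whichever fires first. The combinator pattern in isolation, as a sketch:

use std::future::Future;

use futures_util::{stream::FuturesUnordered, StreamExt};

// Resolves as soon as any of the supplied futures completes.
async fn wait_for_any<F: Future<Output = ()>>(watches: Vec<F>) {
    let mut futures: FuturesUnordered<F> = watches.into_iter().collect();
    futures.next().await;
}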
@ -133,9 +114,7 @@ impl service::globals::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn cleanup(&self) -> Result<()> {
|
||||
self.db.cleanup()
|
||||
}
|
||||
fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||
|
||||
fn memory_usage(&self) -> String {
|
||||
let pdu_cache = self.pdu_cache.lock().unwrap().len();
|
||||
|
@ -210,13 +189,11 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
|||
Ok,
|
||||
)?;
|
||||
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = keypair_bytes.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
utils::string_from_bytes(
|
||||
// 1. version
|
||||
parts
|
||||
.next()
|
||||
.expect("splitn always returns at least one element"),
|
||||
parts.next().expect("splitn always returns at least one element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid version bytes in keypair."))
|
||||
.and_then(|version| {
|
||||
|
@ -231,21 +208,16 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
|||
.map_err(|_| Error::bad_database("Private or public keys are invalid."))
|
||||
})
|
||||
}
|
||||
fn remove_keypair(&self) -> Result<()> {
|
||||
self.global.remove(b"keypair")
|
||||
}
|
||||
|
||||
fn remove_keypair(&self) -> Result<()> { self.global.remove(b"keypair") }
|
||||
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
// Not atomic, but this is not critical
|
||||
let signingkeys = self.server_signingkeys.get(origin.as_bytes())?;
|
||||
|
||||
let mut keys = signingkeys
|
||||
.and_then(|keys| serde_json::from_slice(&keys).ok())
|
||||
.unwrap_or_else(|| {
|
||||
let mut keys = signingkeys.and_then(|keys| serde_json::from_slice(&keys).ok()).unwrap_or_else(|| {
|
||||
// Just insert "now", it doesn't matter
|
||||
ServerSigningKeys::new(origin.to_owned(), MilliSecondsSinceUnixEpoch::now())
|
||||
});
|
||||
|
@ -265,31 +237,21 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
|||
)?;
|
||||
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
|
||||
Ok(tree)
|
||||
}
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let signingkeys = self
|
||||
.server_signingkeys
|
||||
.get(origin.as_bytes())?
|
||||
.and_then(|bytes| serde_json::from_slice(&bytes).ok())
|
||||
.map(|keys: ServerSigningKeys| {
|
||||
let mut tree = keys.verify_keys;
|
||||
tree.extend(
|
||||
keys.old_verify_keys
|
||||
.into_iter()
|
||||
.map(|old| (old.0, VerifyKey::new(old.1.key))),
|
||||
);
|
||||
tree.extend(keys.old_verify_keys.into_iter().map(|old| (old.0, VerifyKey::new(old.1.key))));
|
||||
tree
|
||||
})
|
||||
.unwrap_or_else(BTreeMap::new);
|
||||
|
@ -299,8 +261,7 @@ lasttimelinecount_cache: {lasttimelinecount_cache}\n"
|
|||
|
||||
fn database_version(&self) -> Result<u64> {
|
||||
self.global.get(b"version")?.map_or(Ok(0), |version| {
|
||||
utils::u64_from_bytes(&version)
|
||||
.map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
utils::u64_from_bytes(&version).map_err(|_| Error::bad_database("Database version id is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
|
|
|
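The watch_prefix and scan_prefix calls above all operate on composite keys in which byte segments are joined by a 0xFF separator. A minimal sketch of that convention, assuming hypothetical helper names (make_key, last_segment) that are not part of this repository:

	// Sketch only: make_key and last_segment are illustrative helpers,
	// not functions from the codebase being formatted here.
	fn make_key(segments: &[&[u8]]) -> Vec<u8> {
		let mut key = Vec::new();
		for (i, segment) in segments.iter().enumerate() {
			if i > 0 {
				// 0xFF is the separator byte used throughout the key-value schema
				key.push(0xFF);
			}
			key.extend_from_slice(segment);
		}
		key
	}

	fn last_segment(key: &[u8]) -> Option<&[u8]> {
		// Mirrors the `key.rsplit(|&b| b == 0xFF).next()` calls in the database layer
		key.rsplit(|&b| b == 0xFF).next()
	}

	fn main() {
		let key = make_key(&[&b"@alice:example.org"[..], &b"!room:example.org"[..]]);
		assert_eq!(last_segment(&key), Some(&b"!room:example.org"[..]));
	}

Splitting from the right mirrors the rsplit calls used when only the trailing segment (for example a session or user ID) is needed.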
@@ -12,35 +12,30 @@ use ruma::{
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};

impl service::key_backups::Data for KeyValueDatabase {
fn create_backup(
&self,
user_id: &UserId,
backup_metadata: &Raw<BackupAlgorithm>,
) -> Result<String> {
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
let version = services().globals.next_count()?.to_string();

let mut key = user_id.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(version.as_bytes());

self.backupid_algorithm.insert(
&key,
&serde_json::to_vec(backup_metadata).expect("BackupAlgorithm::to_vec always works"),
)?;
self.backupid_etag
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
Ok(version)
}
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm.remove(&key)?;
|
||||
self.backupid_etag.remove(&key)?;
|
||||
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
|
@ -49,33 +44,23 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_algorithm
|
||||
.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.backupid_algorithm.insert(&key, backup_metadata.json().get().as_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
Ok(version.to_owned())
|
||||
}
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
|
@ -84,22 +69,15 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.next()
|
||||
.map(|(key, _)| {
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
let mut last_possible_key = prefix.clone();
|
||||
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
|
||||
|
||||
|
@ -109,17 +87,14 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
.next()
|
||||
.map(|(key, value)| {
|
||||
let version = utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupid_algorithm key is invalid."))?;
|
||||
|
||||
Ok((
|
||||
version,
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("Algorithm in backupid_algorithm is invalid.")
|
||||
})?,
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))?,
|
||||
))
|
||||
})
|
||||
.transpose()
|
||||
|
@ -127,53 +102,41 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
self.backupid_algorithm
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
self.backupid_algorithm.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("Algorithm in backupid_algorithm is invalid."))
|
||||
})
|
||||
}
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
if self.backupid_algorithm.get(&key)?.is_none() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Tried to update nonexistent backup.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Tried to update nonexistent backup."));
|
||||
}
|
||||
|
||||
self.backupid_etag
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.backupid_etag.insert(&key, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.insert(&key, key_data.json().get().as_bytes())?;
|
||||
self.backupkeyid_backup.insert(&key, key_data.json().get().as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(self.backupkeyid_backup.scan_prefix(prefix).count())
|
||||
|
@ -181,62 +144,45 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
|
||||
Ok(utils::u64_from_bytes(
|
||||
&self
|
||||
.backupid_etag
|
||||
.get(&key)?
|
||||
.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
&self.backupid_etag.get(&key)?.ok_or_else(|| Error::bad_database("Backup has no etag."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("etag in backupid_etag invalid."))?
|
||||
.to_string())
|
||||
}
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut rooms = BTreeMap::<OwnedRoomId, RoomKeyBackup>::new();
|
||||
|
||||
for result in self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
for result in self.backupkeyid_backup.scan_prefix(prefix).map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup room_id is invalid room id.")
|
||||
})?;
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup room_id is invalid room id."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((room_id, session_id, key_data))
|
||||
})
|
||||
{
|
||||
}) {
|
||||
let (room_id, session_id, key_data) = result?;
|
||||
rooms
|
||||
.entry(room_id)
|
||||
|
@ -251,35 +197,28 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(version.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Ok(self
|
||||
.backupkeyid_backup
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, value)| {
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let session_id =
|
||||
utils::string_from_bytes(parts.next().ok_or_else(|| {
|
||||
Error::bad_database("backupkeyid_backup key is invalid.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("backupkeyid_backup session_id is invalid.")
|
||||
})?;
|
||||
let session_id = utils::string_from_bytes(
|
||||
parts.next().ok_or_else(|| Error::bad_database("backupkeyid_backup key is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("backupkeyid_backup session_id is invalid."))?;
|
||||
|
||||
let key_data = serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})?;
|
||||
let key_data = serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))?;
|
||||
|
||||
Ok::<_, Error>((session_id, key_data))
|
||||
})
|
||||
|
@ -288,35 +227,30 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
self.backupkeyid_backup
|
||||
.get(&key)?
|
||||
.map(|value| {
|
||||
serde_json::from_slice(&value).map_err(|_| {
|
||||
Error::bad_database("KeyBackupData in backupkeyid_backup is invalid.")
|
||||
})
|
||||
serde_json::from_slice(&value)
|
||||
.map_err(|_| Error::bad_database("KeyBackupData in backupkeyid_backup is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
|
@ -327,11 +261,11 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
self.backupkeyid_backup.remove(&outdated_key)?;
|
||||
|
@ -340,19 +274,13 @@ impl service::key_backups::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(version.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(session_id.as_bytes());
|
||||
|
||||
for (outdated_key, _) in self.backupkeyid_backup.scan_prefix(key) {
|
||||
|
|
|
@@ -9,31 +9,16 @@ use crate::{

impl service::media::Data for KeyValueDatabase {
fn create_file_metadata(
&self,
mxc: String,
width: u32,
height: u32,
content_disposition: Option<&str>,
content_type: Option<&str>,
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
) -> Result<Vec<u8>> {
let mut key = mxc.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(&width.to_be_bytes());
key.extend_from_slice(&height.to_be_bytes());
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_disposition
|
||||
.as_ref()
|
||||
.map(|f| f.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xff);
|
||||
key.extend_from_slice(
|
||||
content_type
|
||||
.as_ref()
|
||||
.map(|c| c.as_bytes())
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());
|
||||
|
||||
self.mediaid_file.insert(&key, &[])?;
|
||||
|
||||
|
@ -44,7 +29,7 @@ impl service::media::Data for KeyValueDatabase {
|
|||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
debug!("MXC db prefix: {:?}", prefix);
|
||||
|
||||
|
@ -61,7 +46,7 @@ impl service::media::Data for KeyValueDatabase {
|
|||
debug!("MXC URI: {:?}", mxc);
|
||||
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
|
@ -81,16 +66,13 @@ impl service::media::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
&self, mxc: String, width: u32, height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
|
||||
let mut prefix = mxc.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(&width.to_be_bytes());
|
||||
prefix.extend_from_slice(&height.to_be_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let (key, _) = self
|
||||
.mediaid_file
|
||||
|
@ -98,34 +80,32 @@ impl service::media::Data for KeyValueDatabase {
|
|||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::NotFound, "Media not found"))?;
|
||||
|
||||
let mut parts = key.rsplit(|&b| b == 0xff);
|
||||
let mut parts = key.rsplit(|&b| b == 0xFF);
|
||||
|
||||
let content_type = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Content type in mediaid_file is invalid unicode.")
|
||||
})
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Content type in mediaid_file is invalid unicode."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
let content_disposition_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
let content_disposition_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Media ID in db is invalid."))?;
|
||||
|
||||
let content_disposition = if content_disposition_bytes.is_empty() {
|
||||
None
|
||||
} else {
|
||||
Some(
|
||||
utils::string_from_bytes(content_disposition_bytes).map_err(|_| {
|
||||
Error::bad_database("Content Disposition in mediaid_file is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(content_disposition_bytes)
|
||||
.map_err(|_| Error::bad_database("Content Disposition in mediaid_file is invalid unicode."))?,
|
||||
)
|
||||
};
|
||||
Ok((content_disposition, content_type, key))
|
||||
}
|
||||
|
||||
/// Gets all the media keys in our database (this includes all the metadata associated with it such as width, height, content-type, etc)
|
||||
/// Gets all the media keys in our database (this includes all the metadata
|
||||
/// associated with it such as width, height, content-type, etc)
|
||||
fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
|
||||
let mut keys: Vec<Vec<u8>> = vec![];
|
||||
|
||||
|
@ -136,44 +116,22 @@ impl service::media::Data for KeyValueDatabase {
|
|||
Ok(keys)
|
||||
}
|
||||
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> {
|
||||
self.url_previews.remove(url.as_bytes())
|
||||
}
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()> { self.url_previews.remove(url.as_bytes()) }
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &UrlPreviewData,
|
||||
timestamp: std::time::Duration,
|
||||
) -> Result<()> {
|
||||
fn set_url_preview(&self, url: &str, data: &UrlPreviewData, timestamp: std::time::Duration) -> Result<()> {
|
||||
let mut value = Vec::<u8>::new();
|
||||
value.extend_from_slice(×tamp.as_secs().to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.title
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.description
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.extend_from_slice(
|
||||
data.image
|
||||
.as_ref()
|
||||
.map(std::string::String::as_bytes)
|
||||
.unwrap_or_default(),
|
||||
);
|
||||
value.push(0xff);
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.title.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.description.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(data.image.as_ref().map(std::string::String::as_bytes).unwrap_or_default());
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_size.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_width.unwrap_or(0).to_be_bytes());
|
||||
value.push(0xff);
|
||||
value.push(0xFF);
|
||||
value.extend_from_slice(&data.image_height.unwrap_or(0).to_be_bytes());
|
||||
|
||||
self.url_previews.insert(url.as_bytes(), &value)
|
||||
|
@ -182,54 +140,33 @@ impl service::media::Data for KeyValueDatabase {
|
|||
fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
|
||||
let values = self.url_previews.get(url.as_bytes()).ok()??;
|
||||
|
||||
let mut values = values.split(|&b| b == 0xff);
|
||||
let mut values = values.split(|&b| b == 0xFF);
|
||||
|
||||
let _ts = match values
|
||||
.next()
|
||||
.map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
let _ts = match values.next().map(|b| u64::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let title = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
let title = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let description = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
let description = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image = match values
|
||||
.next()
|
||||
.and_then(|b| String::from_utf8(b.to_vec()).ok())
|
||||
{
|
||||
let image = match values.next().and_then(|b| String::from_utf8(b.to_vec()).ok()) {
|
||||
Some(s) if s.is_empty() => None,
|
||||
x => x,
|
||||
};
|
||||
let image_size = match values
|
||||
.next()
|
||||
.map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
let image_size = match values.next().map(|b| usize::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_width = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
let image_width = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
let image_height = match values
|
||||
.next()
|
||||
.map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array")))
|
||||
{
|
||||
let image_height = match values.next().map(|b| u32::from_be_bytes(b.try_into().expect("valid BE array"))) {
|
||||
Some(0) => None,
|
||||
x => x,
|
||||
};
|
||||
|
|
|
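set_url_preview and get_url_preview above pack the preview into a single value: the timestamp, title, description, image URL and the three numeric fields are joined with 0xFF, and an empty string or a zero number stands in for None when reading back. A rough round-trip sketch of that layout, reduced to three of the fields and using an illustrative PreviewFields struct that does not exist in the codebase:

	// Hedged sketch of the 0xFF-joined value layout; field set trimmed for brevity.
	struct PreviewFields {
		ts_secs: u64,
		title: Option<String>,
		image_width: Option<u32>,
	}

	fn encode(p: &PreviewFields) -> Vec<u8> {
		let mut value = Vec::new();
		value.extend_from_slice(&p.ts_secs.to_be_bytes());
		value.push(0xFF);
		// An absent title is stored as the empty string
		value.extend_from_slice(p.title.as_deref().unwrap_or("").as_bytes());
		value.push(0xFF);
		// 0 encodes None, like image_size/width/height in the patch above
		value.extend_from_slice(&p.image_width.unwrap_or(0).to_be_bytes());
		value
	}

	fn decode(value: &[u8]) -> Option<PreviewFields> {
		let mut parts = value.split(|&b| b == 0xFF);
		let ts_secs = u64::from_be_bytes(parts.next()?.try_into().ok()?);
		let title = match String::from_utf8(parts.next()?.to_vec()).ok()? {
			s if s.is_empty() => None,
			s => Some(s),
		};
		let image_width = match u32::from_be_bytes(parts.next()?.try_into().ok()?) {
			0 => None,
			w => Some(w),
		};
		Some(PreviewFields { ts_secs, title, image_width })
	}

	fn main() {
		let original = PreviewFields { ts_secs: 1_700_000_000, title: Some("Example".to_owned()), image_width: None };
		let roundtrip = decode(&encode(&original)).expect("round-trips");
		assert_eq!(roundtrip.title.as_deref(), Some("Example"));
		assert_eq!(roundtrip.image_width, None);
	}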
@@ -10,66 +10,50 @@ impl service::pusher::Data for KeyValueDatabase {
match &pusher {
set_pusher::v3::PusherAction::Post(data) => {
let mut key = sender.as_bytes().to_vec();
key.push(0xff);
key.push(0xFF);
key.extend_from_slice(data.pusher.ids.pushkey.as_bytes());
self.senderkey_pusher.insert(
&key,
&serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"),
)?;
self.senderkey_pusher
.insert(&key, &serde_json::to_vec(&pusher).expect("Pusher is valid JSON value"))?;
Ok(())
}
},
set_pusher::v3::PusherAction::Delete(ids) => {
|
||||
let mut key = sender.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ids.pushkey.as_bytes());
|
||||
self.senderkey_pusher
|
||||
.remove(&key)
|
||||
.map(|_| ())
|
||||
.map_err(Into::into)
|
||||
}
|
||||
self.senderkey_pusher.remove(&key).map(|_| ()).map_err(Into::into)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
fn get_pusher(&self, sender: &UserId, pushkey: &str) -> Result<Option<Pusher>> {
|
||||
let mut senderkey = sender.as_bytes().to_vec();
|
||||
senderkey.push(0xff);
|
||||
senderkey.push(0xFF);
|
||||
senderkey.extend_from_slice(pushkey.as_bytes());
|
||||
|
||||
self.senderkey_pusher
|
||||
.get(&senderkey)?
|
||||
.map(|push| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.map(|push| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.senderkey_pusher
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, push)| {
|
||||
serde_json::from_slice(&push)
|
||||
.map_err(|_| Error::bad_database("Invalid Pusher in db."))
|
||||
})
|
||||
.map(|(_, push)| serde_json::from_slice(&push).map_err(|_| Error::bad_database("Invalid Pusher in db.")))
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn get_pushkeys<'a>(
|
||||
&'a self,
|
||||
sender: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a> {
|
||||
let mut prefix = sender.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.senderkey_pusher.scan_prefix(prefix).map(|(k, _)| {
|
||||
let mut parts = k.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = k.splitn(2, |&b| b == 0xFF);
|
||||
let _senderkey = parts.next();
|
||||
let push_key = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key = parts.next().ok_or_else(|| Error::bad_database("Invalid senderkey_pusher in db"))?;
|
||||
let push_key_string = utils::string_from_bytes(push_key)
|
||||
.map_err(|_| Error::bad_database("Invalid pusher bytes in senderkey_pusher"))?;
|
||||
|
||||
|
|
|
@@ -4,10 +4,9 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
|||
|
||||
impl service::rooms::alias::Data for KeyValueDatabase {
|
||||
fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
|
||||
self.alias_roomid
|
||||
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
self.alias_roomid.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
|
||||
let mut aliasid = room_id.as_bytes().to_vec();
|
||||
aliasid.push(0xff);
|
||||
aliasid.push(0xFF);
|
||||
aliasid.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
|
||||
Ok(())
|
||||
|
@ -16,17 +15,14 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
|||
fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
|
||||
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
|
||||
let mut prefix = room_id;
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
|
||||
self.aliasid_alias.remove(&key)?;
|
||||
}
|
||||
self.alias_roomid.remove(alias.alias().as_bytes())?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"Alias does not exist.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist."));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
@ -35,20 +31,20 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
|||
self.alias_roomid
|
||||
.get(alias.alias().as_bytes())?
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in alias_roomid is invalid unicode.")
|
||||
})?)
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn local_aliases_for_room<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
&'a self, room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
|
||||
utils::string_from_bytes(&bytes)
|
||||
|
@ -58,27 +54,17 @@ impl service::rooms::alias::Data for KeyValueDatabase {
|
|||
}))
|
||||
}
|
||||
|
||||
fn all_local_aliases<'a>(
|
||||
&'a self,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(
|
||||
self.alias_roomid
|
||||
.iter()
|
||||
.map(|(room_alias_bytes, room_id_bytes)| {
|
||||
fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
|
||||
Box::new(self.alias_roomid.iter().map(|(room_alias_bytes, room_id_bytes)| {
|
||||
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid alias bytes in aliasid_alias.")
|
||||
})?;
|
||||
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
|
||||
|
||||
let room_id = utils::string_from_bytes(&room_id_bytes)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid room_id bytes in aliasid_alias.")
|
||||
})?
|
||||
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
|
||||
.try_into()
|
||||
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
|
||||
|
||||
Ok((room_id, room_alias_localpart))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
}
|
||||
|
|
|
@@ -12,10 +12,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
let chain = self.shorteventid_authchain.get(&key[0].to_be_bytes())?.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(|chunk| utils::u64_from_bytes(chunk).expect("byte length is correct"))

@@ -26,10 +23,7 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
|||
let chain = Arc::new(chain);
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(vec![key[0]], Arc::clone(&chain));
|
||||
self.auth_chain_cache.lock().unwrap().insert(vec![key[0]], Arc::clone(&chain));
|
||||
|
||||
return Ok(Some(chain));
|
||||
}
|
||||
|
@ -43,18 +37,12 @@ impl service::rooms::auth_chain::Data for KeyValueDatabase {
|
|||
if key.len() == 1 {
|
||||
self.shorteventid_authchain.insert(
|
||||
&key[0].to_be_bytes(),
|
||||
&auth_chain
|
||||
.iter()
|
||||
.flat_map(|s| s.to_be_bytes().to_vec())
|
||||
.collect::<Vec<u8>>(),
|
||||
&auth_chain.iter().flat_map(|s| s.to_be_bytes().to_vec()).collect::<Vec<u8>>(),
|
||||
)?;
|
||||
}
|
||||
|
||||
// Cache in RAM
|
||||
self.auth_chain_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(key, auth_chain);
|
||||
self.auth_chain_cache.lock().unwrap().insert(key, auth_chain);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
|
|
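The auth-chain cache shown above persists a set of shorteventids as nothing more than their concatenated big-endian u64 bytes, and decodes them again with chunks_exact. A self-contained sketch of that encoding (encode_chain and decode_chain are illustrative names, and std conversions stand in for the crate's utils helpers):

	use std::mem::size_of;

	fn encode_chain(ids: &[u64]) -> Vec<u8> {
		// Concatenate each id's big-endian bytes, as the insert path above does
		ids.iter().flat_map(|s| s.to_be_bytes()).collect()
	}

	fn decode_chain(bytes: &[u8]) -> Vec<u64> {
		// Read the value back in fixed 8-byte chunks, as the lookup path above does
		bytes
			.chunks_exact(size_of::<u64>())
			.map(|chunk| u64::from_be_bytes(chunk.try_into().expect("chunk is 8 bytes")))
			.collect()
	}

	fn main() {
		let chain = vec![1_u64, 42, u64::MAX];
		assert_eq!(decode_chain(&encode_chain(&chain)), chain);
	}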
@@ -3,13 +3,9 @@ use ruma::{OwnedRoomId, RoomId};
|
|||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::directory::Data for KeyValueDatabase {
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.insert(room_id.as_bytes(), &[])
|
||||
}
|
||||
fn set_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.insert(room_id.as_bytes(), &[]) }
|
||||
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
|
||||
self.publicroomids.remove(room_id.as_bytes())
|
||||
}
|
||||
fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.publicroomids.remove(room_id.as_bytes()) }
|
||||
|
||||
fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
|
||||
|
@ -18,9 +14,8 @@ impl service::rooms::directory::Data for KeyValueDatabase {
|
|||
fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
|
||||
}))
|
||||
|
|
|
@@ -1,8 +1,6 @@
use std::time::Duration;

use ruma::{
events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
};
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};
use tracing::error;

use crate::{

@@ -63,18 +61,11 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
|||
presence.last_count = count;
|
||||
|
||||
presence
|
||||
}
|
||||
None => Presence::new(
|
||||
new_state.clone(),
|
||||
new_state == PresenceState::Online,
|
||||
now,
|
||||
count,
|
||||
None,
|
||||
),
|
||||
},
|
||||
None => Presence::new(new_state.clone(), new_state == PresenceState::Online, now, count, None),
|
||||
};
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
self.roomuserid_presence.insert(&key, &new_presence.to_json_bytes()?)?;
|
||||
}
|
||||
|
||||
let timeout = match new_state {
|
||||
|
@ -82,22 +73,15 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
|||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})
|
||||
}
|
||||
|
||||
fn set_presence(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
user_id: &UserId,
|
||||
presence_state: PresenceState,
|
||||
currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>,
|
||||
status_msg: Option<String>,
|
||||
&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
|
||||
last_active_ago: Option<UInt>, status_msg: Option<String>,
|
||||
) -> Result<()> {
|
||||
let now = utils::millis_since_unix_epoch();
|
||||
let last_active_ts = match last_active_ago {
|
||||
|
@ -120,15 +104,12 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
|||
_ => services().globals.config.presence_offline_timeout_s,
|
||||
};
|
||||
|
||||
self.presence_timer_sender
|
||||
.send((user_id.to_owned(), Duration::from_secs(timeout)))
|
||||
.map_err(|e| {
|
||||
self.presence_timer_sender.send((user_id.to_owned(), Duration::from_secs(timeout))).map_err(|e| {
|
||||
error!("Failed to add presence timer: {}", e);
|
||||
Error::bad_database("Failed to add presence timer")
|
||||
})?;
|
||||
|
||||
self.roomuserid_presence
|
||||
.insert(&key, &presence.to_json_bytes()?)?;
|
||||
self.roomuserid_presence.insert(&key, &presence.to_json_bytes()?)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -144,29 +125,25 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn presence_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a> {
|
||||
let prefix = [room_id.as_bytes(), &[0xff]].concat();
|
||||
let prefix = [room_id.as_bytes(), &[0xFF]].concat();
|
||||
|
||||
Box::new(
|
||||
self.roomuserid_presence
|
||||
.scan_prefix(prefix)
|
||||
.flat_map(
|
||||
|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
.flat_map(|(key, presence_bytes)| -> Result<(OwnedUserId, u64, PresenceEvent)> {
|
||||
let user_id = user_id_from_bytes(
|
||||
key.rsplit(|byte| *byte == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("No UserID bytes in presence key")
|
||||
})?,
|
||||
key.rsplit(|byte| *byte == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("No UserID bytes in presence key"))?,
|
||||
)?;
|
||||
|
||||
let presence = Presence::from_json_bytes(&presence_bytes)?;
|
||||
let presence_event = presence.to_presence_event(&user_id)?;
|
||||
|
||||
Ok((user_id, presence.last_count, presence_event))
|
||||
},
|
||||
)
|
||||
})
|
||||
.filter(move |(_, count, _)| *count > since),
|
||||
)
|
||||
}
|
||||
|
@ -174,5 +151,5 @@ impl service::rooms::edus::presence::Data for KeyValueDatabase {
|
|||
|
||||
#[inline]
|
||||
fn presence_key(room_id: &RoomId, user_id: &UserId) -> Vec<u8> {
|
||||
[room_id.as_bytes(), &[0xff], user_id.as_bytes()].concat()
|
||||
[room_id.as_bytes(), &[0xFF], user_id.as_bytes()].concat()
|
||||
}
|
||||
|
|
|
@@ -1,20 +1,13 @@
use std::mem;

use ruma::{
events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId,
};
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, OwnedUserId, RoomId, UserId};

use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};

impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
fn readreceipt_update(
&self,
user_id: &UserId,
room_id: &RoomId,
event: ReceiptEvent,
) -> Result<()> {
fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xff);
prefix.push(0xFF);

let mut last_possible_key = prefix.clone();
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());

@@ -25,19 +18,15 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
|||
.iter_from(&last_possible_key, true)
|
||||
.take_while(|(key, _)| key.starts_with(&prefix))
|
||||
.find(|(key, _)| {
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element")
|
||||
== user_id.as_bytes()
|
||||
})
|
||||
{
|
||||
key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element") == user_id.as_bytes()
|
||||
}) {
|
||||
// This is the old room_latest
|
||||
self.readreceiptid_readreceipt.remove(&old)?;
|
||||
}
|
||||
|
||||
let mut room_latest_id = prefix;
|
||||
room_latest_id.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
room_latest_id.push(0xff);
|
||||
room_latest_id.push(0xFF);
|
||||
room_latest_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.readreceiptid_readreceipt.insert(
|
||||
|
@ -49,20 +38,10 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn readreceipts_since<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
since: u64,
|
||||
) -> Box<
|
||||
dyn Iterator<
|
||||
Item = Result<(
|
||||
OwnedUserId,
|
||||
u64,
|
||||
Raw<ruma::events::AnySyncEphemeralRoomEvent>,
|
||||
)>,
|
||||
> + 'a,
|
||||
> {
|
||||
&'a self, room_id: &RoomId, since: u64,
|
||||
) -> Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<ruma::events::AnySyncEphemeralRoomEvent>)>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
let prefix2 = prefix.clone();
|
||||
|
||||
let mut first_possible_edu = prefix.clone();
|
||||
|
@ -73,33 +52,22 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
|||
.iter_from(&first_possible_edu, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix2))
|
||||
.map(move |(k, v)| {
|
||||
let count = utils::u64_from_bytes(
|
||||
&k[prefix.len()..prefix.len() + mem::size_of::<u64>()],
|
||||
)
|
||||
let count = utils::u64_from_bytes(&k[prefix.len()..prefix.len() + mem::size_of::<u64>()])
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&k[prefix.len() + mem::size_of::<u64>() + 1..])
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Invalid readreceiptid userid bytes in db.")
|
||||
})?,
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
|
||||
|
||||
let mut json =
|
||||
serde_json::from_slice::<CanonicalJsonObject>(&v).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Read receipt in roomlatestid_roomlatest is invalid json.",
|
||||
)
|
||||
})?;
|
||||
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
|
||||
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
|
||||
json.remove("room_id");
|
||||
|
||||
Ok((
|
||||
user_id,
|
||||
count,
|
||||
Raw::from_json(
|
||||
serde_json::value::to_raw_value(&json)
|
||||
.expect("json is valid raw value"),
|
||||
),
|
||||
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
|
||||
))
|
||||
}),
|
||||
)
|
||||
|
@ -107,42 +75,37 @@ impl service::rooms::edus::read_receipt::Data for KeyValueDatabase {
|
|||
|
||||
fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.insert(&key, &count.to_be_bytes())?;
|
||||
self.roomuserid_privateread.insert(&key, &count.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastprivatereadupdate
|
||||
.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
self.roomuserid_lastprivatereadupdate.insert(&key, &services().globals.next_count()?.to_be_bytes())
|
||||
}
|
||||
|
||||
fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_privateread
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |v| {
|
||||
Ok(Some(utils::u64_from_bytes(&v).map_err(|_| {
|
||||
Error::bad_database("Invalid private read marker bytes")
|
||||
})?))
|
||||
self.roomuserid_privateread.get(&key)?.map_or(Ok(None), |v| {
|
||||
Ok(Some(
|
||||
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastprivatereadupdate
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
|
|
|
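Several of the readers above (latest backup version, latest read receipt) find the newest entry under a prefix by appending u64::MAX as the count and iterating backwards while keys still start with the prefix. A sketch of the same idea over a BTreeMap standing in for the ordered key-value store (latest_under_prefix is an illustrative helper, not repository code):

	use std::collections::BTreeMap;
	use std::ops::Bound;

	fn latest_under_prefix<'a>(
		db: &'a BTreeMap<Vec<u8>, Vec<u8>>, prefix: &[u8],
	) -> Option<(&'a Vec<u8>, &'a Vec<u8>)> {
		// "prefix ++ u64::MAX" sorts after every real "prefix ++ count" key
		let mut last_possible_key = prefix.to_vec();
		last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());

		db.range((Bound::Unbounded, Bound::Included(last_possible_key)))
			.rev()
			.take_while(|(k, _)| k.starts_with(prefix))
			.next()
	}

	fn main() {
		let mut db = BTreeMap::new();
		db.insert([b"room1".as_slice(), &[0xFF], &1_u64.to_be_bytes()[..]].concat(), b"old".to_vec());
		db.insert([b"room1".as_slice(), &[0xFF], &7_u64.to_be_bytes()[..]].concat(), b"new".to_vec());
		db.insert(b"room2".to_vec(), b"other".to_vec());

		let prefix = [b"room1".as_slice(), &[0xFF]].concat();
		let (_, value) = latest_under_prefix(&db, &prefix).expect("room1 has entries");
		assert_eq!(value.as_slice(), b"new".as_slice());
	}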
@@ -7,47 +7,38 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
|||
impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
||||
fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
|
||||
let mut room_typing_id = prefix;
|
||||
room_typing_id.extend_from_slice(&timeout.to_be_bytes());
|
||||
room_typing_id.push(0xff);
|
||||
room_typing_id.push(0xFF);
|
||||
room_typing_id.extend_from_slice(&count);
|
||||
|
||||
self.typingid_userid
|
||||
.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
self.typingid_userid.insert(&room_typing_id, user_id.as_bytes())?;
|
||||
|
||||
self.roomid_lasttypingupdate
|
||||
.insert(room_id.as_bytes(), &count)?;
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &count)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let user_id = user_id.to_string();
|
||||
|
||||
let mut found_outdated = false;
|
||||
|
||||
// Maybe there are multiple ones from calling roomtyping_add multiple times
|
||||
for outdated_edu in self
|
||||
.typingid_userid
|
||||
.scan_prefix(prefix)
|
||||
.filter(|(_, v)| &**v == user_id.as_bytes())
|
||||
{
|
||||
for outdated_edu in self.typingid_userid.scan_prefix(prefix).filter(|(_, v)| &**v == user_id.as_bytes()) {
|
||||
self.typingid_userid.remove(&outdated_edu.0)?;
|
||||
found_outdated = true;
|
||||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -55,7 +46,7 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
|||
|
||||
fn typings_maintain(&self, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let current_timestamp = utils::millis_since_unix_epoch();
|
||||
|
||||
|
@ -69,9 +60,9 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
|||
Ok::<_, Error>((
|
||||
key.clone(),
|
||||
utils::u64_from_bytes(
|
||||
&key.splitn(2, |&b| b == 0xff).nth(1).ok_or_else(|| {
|
||||
Error::bad_database("RoomTyping has invalid timestamp or delimiters.")
|
||||
})?[0..mem::size_of::<u64>()],
|
||||
&key.splitn(2, |&b| b == 0xFF)
|
||||
.nth(1)
|
||||
.ok_or_else(|| Error::bad_database("RoomTyping has invalid timestamp or delimiters."))?[0..mem::size_of::<u64>()],
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomTyping has invalid timestamp bytes."))?,
|
||||
))
|
||||
|
@ -85,10 +76,7 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
if found_outdated {
|
||||
self.roomid_lasttypingupdate.insert(
|
||||
room_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomid_lasttypingupdate.insert(room_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
|
@ -99,9 +87,8 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
|||
.roomid_lasttypingupdate
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
|
@ -109,14 +96,15 @@ impl service::rooms::edus::typing::Data for KeyValueDatabase {
|
|||
|
||||
fn typings_all(&self, room_id: &RoomId) -> Result<HashSet<OwnedUserId>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut user_ids = HashSet::new();
|
||||
|
||||
for (_, user_id) in self.typingid_userid.scan_prefix(prefix) {
|
||||
let user_id = UserId::parse(utils::string_from_bytes(&user_id).map_err(|_| {
|
||||
Error::bad_database("User ID in typingid_userid is invalid unicode.")
|
||||
})?)
|
||||
let user_id = UserId::parse(
|
||||
utils::string_from_bytes(&user_id)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in typingid_userid is invalid."))?;
|
||||
|
||||
user_ids.insert(user_id);
|
||||
|
|
|
@@ -4,35 +4,28 @@ use crate::{database::KeyValueDatabase, service, Result};
|
|||
|
||||
impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
||||
fn lazy_load_was_sent_before(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
ll_user: &UserId,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
|
||||
) -> Result<bool> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(ll_user.as_bytes());
|
||||
Ok(self.lazyloadedids.get(&key)?.is_some())
|
||||
}
|
||||
|
||||
fn lazy_load_confirm_delivery(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
|
||||
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
|
||||
) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for ll_id in confirmed_user_ids {
|
||||
let mut key = prefix.clone();
|
||||
|
@ -43,18 +36,13 @@ impl service::rooms::lazy_loading::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn lazy_load_reset(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(room_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
|
||||
self.lazyloadedids.remove(&key)?;
|
||||
|
|
|
@@ -11,20 +11,14 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
|
|||
};
|
||||
|
||||
// Look for PDUs in that room.
|
||||
Ok(self
|
||||
.pduid_pdu
|
||||
.iter_from(&prefix, false)
|
||||
.next()
|
||||
.filter(|(k, _)| k.starts_with(&prefix))
|
||||
.is_some())
|
||||
Ok(self.pduid_pdu.iter_from(&prefix, false).next().filter(|(k, _)| k.starts_with(&prefix)).is_some())
|
||||
}
|
||||
|
||||
fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Room ID in publicroomids is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
|
||||
}))
|
||||
|
@ -44,9 +38,7 @@ impl service::rooms::metadata::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
|
||||
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn is_banned(&self, room_id: &RoomId) -> Result<bool> { Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some()) }
|
||||
|
||||
fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
|
||||
if banned {
|
||||
|
|
|
@@ -4,17 +4,13 @@ use crate::{database::KeyValueDatabase, service, Error, PduEvent, Result};
|
|||
|
||||
impl service::rooms::outlier::Data for KeyValueDatabase {
|
||||
fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
||||
fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |pdu| {
|
||||
self.eventid_outlierpdu.get(event_id.as_bytes())?.map_or(Ok(None), |pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
}
|
||||
|
|
|
@@ -17,11 +17,7 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn relations_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
shortroomid: u64,
|
||||
target: u64,
|
||||
until: PduCount,
|
||||
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let prefix = target.to_be_bytes().to_vec();
|
||||
let mut current = prefix.clone();
|
||||
|
@ -31,15 +27,13 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
|||
PduCount::Backfilled(x) => {
|
||||
current.extend_from_slice(&0_u64.to_be_bytes());
|
||||
u64::MAX - x - 1
|
||||
}
|
||||
},
|
||||
};
|
||||
current.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.tofrom_relation
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(tofrom, _data)| {
|
||||
self.tofrom_relation.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(tofrom, _data)| {
|
||||
let from = utils::u64_from_bytes(&tofrom[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
|
||||
|
||||
|
@ -55,7 +49,8 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
|||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((PduCount::Normal(from), pdu))
|
||||
}),
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
|
@ -80,8 +75,6 @@ impl service::rooms::pdu_metadata::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
|
||||
self.softfailedeventids
|
||||
.get(event_id.as_bytes())
|
||||
.map(|o| o.is_some())
|
||||
self.softfailedeventids.get(event_id.as_bytes()).map(|o| o.is_some())
|
||||
}
|
||||
}
|
||||
|
|
|
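relations_until above seeds an iterator at a (target, count) key and walks backwards while keys still share the target prefix. A compact model of that traversal over an ordered map (the map, key layout and helper name are illustrative, not the real tofrom_relation tree):

use std::collections::BTreeMap;

// Sketch: iterate backwards from `start` (inclusive) while keys keep `prefix`.
fn walk_back<'a>(
    tree: &'a BTreeMap<Vec<u8>, u64>,
    prefix: &'a [u8],
    start: &[u8],
) -> impl Iterator<Item = (&'a Vec<u8>, &'a u64)> + 'a {
    tree.range(..=start.to_vec())
        .rev()
        .take_while(move |(k, _)| k.starts_with(prefix))
}

fn main() {
    let mut tree = BTreeMap::new();
    for count in 1u64..=3 {
        let mut key = b"target\xff".to_vec();
        key.extend_from_slice(&count.to_be_bytes());
        tree.insert(key, count);
    }
    let mut start = b"target\xff".to_vec();
    start.extend_from_slice(&2u64.to_be_bytes());
    let seen: Vec<u64> = walk_back(&tree, b"target\xff", &start).map(|(_, v)| *v).collect();
    assert_eq!(seen, vec![2, 1]); // newest-first, stopping at the prefix boundary
}
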
@@ -14,7 +14,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
 .map(|word| {
 let mut key = shortroomid.to_be_bytes().to_vec();
 key.extend_from_slice(word.as_bytes());
-key.push(0xff);
+key.push(0xFF);
 key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
 (key, Vec::new())
 });
@@ -23,13 +23,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
 }

 fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
-let prefix = services()
-.rooms
-.short
-.get_shortroomid(room_id)?
-.expect("room exists")
-.to_be_bytes()
-.to_vec();
+let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();

 let words: Vec<_> = search_string
 .split_terminator(|c: char| !c.is_alphanumeric())
@@ -40,7 +34,7 @@ impl service::rooms::search::Data for KeyValueDatabase {
 let iterators = words.clone().into_iter().map(move |word| {
 let mut prefix2 = prefix.clone();
 prefix2.extend_from_slice(word.as_bytes());
-prefix2.push(0xff);
+prefix2.push(0xFF);
 let prefix3 = prefix2.clone();

 let mut last_possible_id = prefix2.clone();

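search_pdus above splits the query with split_terminator on non-alphanumeric characters before building one key per word. A small sketch of that tokenisation step; the lowercasing is an assumption of the sketch, not something shown in this hunk:

// Sketch: split a query on any non-alphanumeric character and drop empties.
fn tokenize(query: &str) -> Vec<String> {
    query
        .split_terminator(|c: char| !c.is_alphanumeric())
        .filter(|word| !word.is_empty())
        .map(str::to_lowercase)
        .collect()
}

fn main() {
    assert_eq!(tokenize("Hello, Matrix world!"), vec!["hello", "matrix", "world"]);
}
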
@ -12,79 +12,57 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
let short = match self.eventid_shorteventid.get(event_id.as_bytes())? {
|
||||
Some(shorteventid) => utils::u64_from_bytes(&shorteventid)
|
||||
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
|
||||
Some(shorteventid) => {
|
||||
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
|
||||
},
|
||||
None => {
|
||||
let shorteventid = services().globals.next_count()?;
|
||||
self.eventid_shorteventid
|
||||
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid
|
||||
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
self.eventid_shorteventid.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
|
||||
self.shorteventid_eventid.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
|
||||
shorteventid
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
self.eventidshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), short);
|
||||
self.eventidshort_cache.lock().unwrap().insert(event_id.to_owned(), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<Option<u64>> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(Some(*short));
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = self
|
||||
.statekey_shortstatekey
|
||||
.get(&statekey_vec)?
|
||||
.map(|shortstatekey| {
|
||||
utils::u64_from_bytes(&shortstatekey)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
|
||||
})
|
||||
.transpose()?;
|
||||
|
||||
if let Some(s) = short {
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), s);
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), s);
|
||||
}
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_or_create_shortstatekey(
|
||||
&self,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
) -> Result<u64> {
|
||||
if let Some(short) = self
|
||||
.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
|
||||
if let Some(short) =
|
||||
self.statekeyshort_cache.lock().unwrap().get_mut(&(event_type.clone(), state_key.to_owned()))
|
||||
{
|
||||
return Ok(*short);
|
||||
}
|
||||
|
||||
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
|
||||
statekey_vec.push(0xff);
|
||||
statekey_vec.push(0xFF);
|
||||
statekey_vec.extend_from_slice(state_key.as_bytes());
|
||||
|
||||
let short = match self.statekey_shortstatekey.get(&statekey_vec)? {
|
||||
|
@ -92,29 +70,19 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
|||
.map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?,
|
||||
None => {
|
||||
let shortstatekey = services().globals.next_count()?;
|
||||
self.statekey_shortstatekey
|
||||
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey
|
||||
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
self.statekey_shortstatekey.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
|
||||
self.shortstatekey_statekey.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
|
||||
shortstatekey
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
self.statekeyshort_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert((event_type.clone(), state_key.to_owned()), short);
|
||||
self.statekeyshort_cache.lock().unwrap().insert((event_type.clone(), state_key.to_owned()), short);
|
||||
|
||||
Ok(short)
|
||||
}
|
||||
|
||||
fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
|
||||
if let Some(id) = self
|
||||
.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shorteventid)
|
||||
{
|
||||
if let Some(id) = self.shorteventid_cache.lock().unwrap().get_mut(&shorteventid) {
|
||||
return Ok(Arc::clone(id));
|
||||
}
|
||||
|
||||
|
@ -123,26 +91,19 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
|||
.get(&shorteventid.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
|
||||
|
||||
let event_id = EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in shorteventid_eventid is invalid unicode.")
|
||||
})?)
|
||||
let event_id = EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
|
||||
|
||||
self.shorteventid_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shorteventid, Arc::clone(&event_id));
|
||||
self.shorteventid_cache.lock().unwrap().insert(shorteventid, Arc::clone(&event_id));
|
||||
|
||||
Ok(event_id)
|
||||
}
|
||||
|
||||
fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
|
||||
if let Some(id) = self
|
||||
.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.get_mut(&shortstatekey)
|
||||
{
|
||||
if let Some(id) = self.shortstatekey_cache.lock().unwrap().get_mut(&shortstatekey) {
|
||||
return Ok(id.clone());
|
||||
}
|
||||
|
||||
|
@ -151,28 +112,22 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
|||
.get(&shortstatekey.to_be_bytes())?
|
||||
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
|
||||
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
|
||||
let eventtype_bytes = parts.next().expect("split always returns one entry");
|
||||
let statekey_bytes = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
let statekey_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
|
||||
|
||||
let event_type =
|
||||
StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
|
||||
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
|
||||
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
|
||||
})?);
|
||||
|
||||
let state_key = utils::string_from_bytes(statekey_bytes).map_err(|_| {
|
||||
Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode.")
|
||||
})?;
|
||||
let state_key = utils::string_from_bytes(statekey_bytes)
|
||||
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
|
||||
|
||||
let result = (event_type, state_key);
|
||||
|
||||
self.shortstatekey_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(shortstatekey, result.clone());
|
||||
self.shortstatekey_cache.lock().unwrap().insert(shortstatekey, result.clone());
|
||||
|
||||
Ok(result)
|
||||
}
|
||||
|
@ -187,33 +142,29 @@ impl service::rooms::short::Data for KeyValueDatabase {
|
|||
),
|
||||
None => {
|
||||
let shortstatehash = services().globals.next_count()?;
|
||||
self.statehash_shortstatehash
|
||||
.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
self.statehash_shortstatehash.insert(state_hash, &shortstatehash.to_be_bytes())?;
|
||||
(shortstatehash, false)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortroomid
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))
|
||||
})
|
||||
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
|
||||
Ok(match self.roomid_shortroomid.get(room_id.as_bytes())? {
|
||||
Some(short) => utils::u64_from_bytes(&short)
|
||||
.map_err(|_| Error::bad_database("Invalid shortroomid in db."))?,
|
||||
Some(short) => {
|
||||
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
|
||||
},
|
||||
None => {
|
||||
let short = services().globals.next_count()?;
|
||||
self.roomid_shortroomid
|
||||
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
self.roomid_shortroomid.insert(room_id.as_bytes(), &short.to_be_bytes())?;
|
||||
short
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
}
|
||||
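The short-ID functions ending above all follow one get-or-create shape: look the mapping up, otherwise take the next global counter and store it in both directions. A stripped-down in-memory sketch of that shape (struct and field names are illustrative):

use std::collections::HashMap;

struct ShortIds {
    next: u64,
    forward: HashMap<String, u64>, // e.g. event_id -> shorteventid
    backward: HashMap<u64, String>,
}

impl ShortIds {
    // Sketch: return the existing short id or allocate the next counter value
    // and record the mapping both ways, as the code above does in the database.
    fn get_or_create(&mut self, id: &str) -> u64 {
        if let Some(&short) = self.forward.get(id) {
            return short;
        }
        self.next += 1;
        let short = self.next;
        self.forward.insert(id.to_owned(), short);
        self.backward.insert(short, id.to_owned());
        short
    }
}

fn main() {
    let mut ids = ShortIds { next: 0, forward: HashMap::new(), backward: HashMap::new() };
    let a = ids.get_or_create("$event_a");
    assert_eq!(a, ids.get_or_create("$event_a")); // stable on repeat lookups
    assert_ne!(a, ids.get_or_create("$event_b"));
    assert_eq!(ids.backward[&a], "$event_a");
}
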
@ -1,16 +1,13 @@
|
|||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use std::collections::HashSet;
|
||||
use std::{collections::HashSet, sync::Arc};
|
||||
|
||||
use std::sync::Arc;
|
||||
use ruma::{EventId, OwnedEventId, RoomId};
|
||||
use tokio::sync::MutexGuard;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, utils, Error, Result};
|
||||
|
||||
impl service::rooms::state::Data for KeyValueDatabase {
|
||||
fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_shortstatehash
|
||||
.get(room_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
self.roomid_shortstatehash.get(room_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
|
||||
})?))
|
||||
|
@ -23,27 +20,26 @@ impl service::rooms::state::Data for KeyValueDatabase {
|
|||
new_shortstatehash: u64,
|
||||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
self.roomid_shortstatehash
|
||||
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
self.roomid_shortstatehash.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
|
||||
self.shorteventid_shortstatehash
|
||||
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
self.shorteventid_shortstatehash.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.roomid_pduleaves
|
||||
.scan_prefix(prefix)
|
||||
.map(|(_, bytes)| {
|
||||
EventId::parse_arc(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("EventID in roomid_pduleaves is invalid unicode.")
|
||||
})?)
|
||||
EventId::parse_arc(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
|
||||
})
|
||||
.collect()
|
||||
|
@ -56,7 +52,7 @@ impl service::rooms::state::Data for KeyValueDatabase {
|
|||
_mutex_lock: &MutexGuard<'_, ()>, // Take mutex guard to make sure users get the room state mutex
|
||||
) -> Result<()> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
|
||||
self.roomid_pduleaves.remove(&key)?;
|
||||
@ -1,9 +1,10 @@
|
|||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
use async_trait::async_trait;
|
||||
use ruma::{events::StateEventType, EventId, RoomId};
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
#[async_trait]
|
||||
impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
||||
async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
|
||||
|
@ -17,10 +18,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
|||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let parsed = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
let parsed = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
result.insert(parsed.0, parsed.1);
|
||||
|
||||
i += 1;
|
||||
|
@ -31,10 +29,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
async fn state_full(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
let full_state = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
|
@ -46,10 +41,7 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
|||
let mut result = HashMap::new();
|
||||
let mut i = 0;
|
||||
for compressed in full_state.iter() {
|
||||
let (_, eventid) = services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)?;
|
||||
let (_, eventid) = services().rooms.state_compressor.parse_compressed_state_event(compressed)?;
|
||||
if let Some(pdu) = services().rooms.timeline.get_pdu(&eventid)? {
|
||||
result.insert(
|
||||
(
|
||||
|
@ -72,18 +64,12 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
|||
Ok(result)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get_id(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
let shortstatekey = match services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortstatekey(event_type, state_key)?
|
||||
{
|
||||
let shortstatekey = match services().rooms.short.get_shortstatekey(event_type, state_key)? {
|
||||
Some(s) => s,
|
||||
None => return Ok(None),
|
||||
};
|
||||
|
@ -94,90 +80,62 @@ impl service::rooms::state_accessor::Data for KeyValueDatabase {
|
|||
.pop()
|
||||
.expect("there is always one layer")
|
||||
.1;
|
||||
Ok(full_state
|
||||
.iter()
|
||||
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
|
||||
.and_then(|compressed| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.parse_compressed_state_event(compressed)
|
||||
.ok()
|
||||
.map(|(_, id)| id)
|
||||
}))
|
||||
Ok(
|
||||
full_state.iter().find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes())).and_then(|compressed| {
|
||||
services().rooms.state_compressor.parse_compressed_state_event(compressed).ok().map(|(_, id)| id)
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn state_get(
|
||||
&self,
|
||||
shortstatehash: u64,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
self.state_get_id(shortstatehash, event_type, state_key)?
|
||||
.map_or(Ok(None), |event_id| {
|
||||
services().rooms.timeline.get_pdu(&event_id)
|
||||
})
|
||||
.map_or(Ok(None), |event_id| services().rooms.timeline.get_pdu(&event_id))
|
||||
}
|
||||
|
||||
/// Returns the state hash for this pdu.
|
||||
fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
|
||||
self.eventid_shorteventid
|
||||
.get(event_id.as_bytes())?
|
||||
.map_or(Ok(None), |shorteventid| {
|
||||
self.eventid_shorteventid.get(event_id.as_bytes())?.map_or(Ok(None), |shorteventid| {
|
||||
self.shorteventid_shortstatehash
|
||||
.get(&shorteventid)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"Invalid shortstatehash bytes in shorteventid_shortstatehash",
|
||||
)
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
})
|
||||
}
|
||||
|
||||
/// Returns the full room state.
|
||||
async fn room_state_full(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
async fn room_state_full(&self, room_id: &RoomId) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_full(current_shortstatehash).await
|
||||
} else {
|
||||
Ok(HashMap::new())
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get_id(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<EventId>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get_id(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`, `state_key`).
|
||||
/// Returns a single PDU from `room_id` with key (`event_type`,
|
||||
/// `state_key`).
|
||||
fn room_state_get(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
event_type: &StateEventType,
|
||||
state_key: &str,
|
||||
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
|
||||
) -> Result<Option<Arc<PduEvent>>> {
|
||||
if let Some(current_shortstatehash) =
|
||||
services().rooms.state.get_room_shortstatehash(room_id)?
|
||||
{
|
||||
if let Some(current_shortstatehash) = services().rooms.state.get_room_shortstatehash(room_id)? {
|
||||
self.state_get(current_shortstatehash, event_type, state_key)
|
||||
} else {
|
||||
Ok(None)
|
||||
@ -10,27 +10,25 @@ use ruma::{
|
|||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, Result};
|
||||
|
||||
type StrippedStateEventIter<'a> =
|
||||
Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
|
||||
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
|
||||
|
||||
type AnySyncStateEventIter<'a> =
|
||||
Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
|
||||
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
|
||||
|
||||
impl service::rooms::state_cache::Data for KeyValueDatabase {
|
||||
fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
self.roomuseroncejoinedids.insert(&userroom_id, &[])
|
||||
}
|
||||
|
||||
fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_joined.insert(&userroom_id, &[])?;
|
||||
|
@ -44,28 +42,21 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn mark_as_invited(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
|
||||
) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_invitestate.insert(
|
||||
&userroom_id,
|
||||
&serde_json::to_vec(&last_state.unwrap_or_default())
|
||||
.expect("state to bytes always works"),
|
||||
)?;
|
||||
self.roomuserid_invitecount.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
|
||||
)?;
|
||||
self.roomuserid_invitecount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.userroomid_joined.remove(&userroom_id)?;
|
||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||
|
@ -76,21 +67,18 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
|
||||
fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate.insert(
|
||||
&userroom_id,
|
||||
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
|
||||
)?; // TODO
|
||||
self.roomuserid_leftcount.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomuserid_leftcount.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||
self.userroomid_joined.remove(&userroom_id)?;
|
||||
self.roomuserid_joined.remove(&roomuser_id)?;
|
||||
self.userroomid_invitestate.remove(&userroom_id)?;
|
||||
|
@ -105,10 +93,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
let mut joined_servers = HashSet::new();
|
||||
let mut real_users = HashSet::new();
|
||||
|
||||
for joined in self
|
||||
.room_members(room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for joined in self.room_members(room_id).filter_map(std::result::Result::ok) {
|
||||
joined_servers.insert(joined.server_name().to_owned());
|
||||
if joined.server_name() == services().globals.server_name()
|
||||
&& !services().users.is_deactivated(&joined).unwrap_or(true)
|
||||
|
@ -118,36 +103,25 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
joinedcount += 1;
|
||||
}
|
||||
|
||||
for _invited in self
|
||||
.room_members_invited(room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for _invited in self.room_members_invited(room_id).filter_map(std::result::Result::ok) {
|
||||
invitedcount += 1;
|
||||
}
|
||||
|
||||
self.roomid_joinedcount
|
||||
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
|
||||
self.roomid_joinedcount.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
|
||||
|
||||
self.roomid_invitedcount
|
||||
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
|
||||
self.roomid_invitedcount.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
|
||||
|
||||
self.our_real_users_cache
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(room_id.to_owned(), Arc::new(real_users));
|
||||
self.our_real_users_cache.write().unwrap().insert(room_id.to_owned(), Arc::new(real_users));
|
||||
|
||||
for old_joined_server in self
|
||||
.room_servers(room_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for old_joined_server in self.room_servers(room_id).filter_map(std::result::Result::ok) {
|
||||
if !joined_servers.remove(&old_joined_server) {
|
||||
// Server not in room anymore
|
||||
let mut roomserver_id = room_id.as_bytes().to_vec();
|
||||
roomserver_id.push(0xff);
|
||||
roomserver_id.push(0xFF);
|
||||
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
|
||||
|
||||
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
|
||||
serverroom_id.push(0xff);
|
||||
serverroom_id.push(0xFF);
|
||||
serverroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.roomserverids.remove(&roomserver_id)?;
|
||||
|
@ -158,70 +132,44 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
// Now only new servers are in joined_servers anymore
|
||||
for server in joined_servers {
|
||||
let mut roomserver_id = room_id.as_bytes().to_vec();
|
||||
roomserver_id.push(0xff);
|
||||
roomserver_id.push(0xFF);
|
||||
roomserver_id.extend_from_slice(server.as_bytes());
|
||||
|
||||
let mut serverroom_id = server.as_bytes().to_vec();
|
||||
serverroom_id.push(0xff);
|
||||
serverroom_id.push(0xFF);
|
||||
serverroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.roomserverids.insert(&roomserver_id, &[])?;
|
||||
self.serverroomids.insert(&serverroom_id, &[])?;
|
||||
}
|
||||
|
||||
self.appservice_in_room_cache
|
||||
.write()
|
||||
.unwrap()
|
||||
.remove(room_id);
|
||||
self.appservice_in_room_cache.write().unwrap().remove(room_id);
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id))]
|
||||
fn get_our_real_users(&self, room_id: &RoomId) -> Result<Arc<HashSet<OwnedUserId>>> {
|
||||
let maybe = self
|
||||
.our_real_users_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.cloned();
|
||||
let maybe = self.our_real_users_cache.read().unwrap().get(room_id).cloned();
|
||||
if let Some(users) = maybe {
|
||||
Ok(users)
|
||||
} else {
|
||||
self.update_joined_count(room_id)?;
|
||||
Ok(Arc::clone(
|
||||
self.our_real_users_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.unwrap(),
|
||||
))
|
||||
Ok(Arc::clone(self.our_real_users_cache.read().unwrap().get(room_id).unwrap()))
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self, room_id, appservice))]
|
||||
fn appservice_in_room(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
appservice: &(String, Registration),
|
||||
) -> Result<bool> {
|
||||
let maybe = self
|
||||
.appservice_in_room_cache
|
||||
.read()
|
||||
.unwrap()
|
||||
.get(room_id)
|
||||
.and_then(|map| map.get(&appservice.0))
|
||||
.copied();
|
||||
fn appservice_in_room(&self, room_id: &RoomId, appservice: &(String, Registration)) -> Result<bool> {
|
||||
let maybe =
|
||||
self.appservice_in_room_cache.read().unwrap().get(room_id).and_then(|map| map.get(&appservice.0)).copied();
|
||||
|
||||
if let Some(b) = maybe {
|
||||
Ok(b)
|
||||
} else {
|
||||
let namespaces = &appservice.1.namespaces;
|
||||
let users = namespaces
|
||||
.users
|
||||
.iter()
|
||||
.filter_map(|users| Regex::new(users.regex.as_str()).ok())
|
||||
.collect::<Vec<_>>();
|
||||
let users =
|
||||
namespaces.users.iter().filter_map(|users| Regex::new(users.regex.as_str()).ok()).collect::<Vec<_>>();
|
||||
|
||||
let bridge_user_id = UserId::parse_with_server_name(
|
||||
appservice.1.sender_localpart.as_str(),
|
||||
|
@ -229,13 +177,10 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
)
|
||||
.ok();
|
||||
|
||||
let in_room = bridge_user_id
|
||||
.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|
||||
|| self.room_members(room_id).any(|userid| {
|
||||
userid.map_or(false, |userid| {
|
||||
users.iter().any(|r| r.is_match(userid.as_str()))
|
||||
})
|
||||
});
|
||||
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|
||||
|| self
|
||||
.room_members(room_id)
|
||||
.any(|userid| userid.map_or(false, |userid| users.iter().any(|r| r.is_match(userid.as_str()))));
|
||||
|
||||
self.appservice_in_room_cache
|
||||
.write()
|
||||
|
@ -252,11 +197,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate.remove(&userroom_id)?;
|
||||
|
@ -267,23 +212,14 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
|
||||
/// Returns an iterator of all servers participating in this room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_servers<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
|
||||
fn room_servers<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
|
||||
ServerName::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Server name in roomserverids is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
|
||||
}))
|
||||
|
@ -292,28 +228,22 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
|
||||
let mut key = server.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.serverroomids.get(&key).map(|o| o.is_some())
|
||||
}
|
||||
|
||||
/// Returns an iterator of all rooms a server participates in (as far as we know).
|
||||
/// Returns an iterator of all rooms a server participates in (as far as we
|
||||
/// know).
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn server_rooms<'a>(
|
||||
&'a self,
|
||||
server: &ServerName,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
fn server_rooms<'a>(&'a self, server: &ServerName) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
let mut prefix = server.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
|
||||
|
@ -322,23 +252,14 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
|
||||
/// Returns an iterator over all joined members of a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_members<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
fn room_members<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in roomuserid_joined is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
|
||||
}))
|
||||
|
@ -348,10 +269,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_joinedcount
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|b| {
|
||||
utils::u64_from_bytes(&b)
|
||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
||||
})
|
||||
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
|
@ -359,167 +277,101 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
|
||||
self.roomid_invitedcount
|
||||
.get(room_id.as_bytes())?
|
||||
.map(|b| {
|
||||
utils::u64_from_bytes(&b)
|
||||
.map_err(|_| Error::bad_database("Invalid joinedcount in db."))
|
||||
})
|
||||
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all User IDs who ever joined a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_useroncejoined<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
fn room_useroncejoined<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(
|
||||
self.roomuseroncejoinedids
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Box::new(self.roomuseroncejoinedids.scan_prefix(prefix).map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database(
|
||||
"User ID in room_useroncejoined is invalid unicode.",
|
||||
)
|
||||
})?,
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns an iterator over all invited members of a room.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn room_members_invited<'a>(
|
||||
&'a self,
|
||||
room_id: &RoomId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
fn room_members_invited<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
let mut prefix = room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(
|
||||
self.roomuserid_invitecount
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, _)| {
|
||||
Box::new(self.roomuserid_invitecount.scan_prefix(prefix).map(|(key, _)| {
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in roomuserid_invited is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_invitecount
|
||||
.get(&key)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid invitecount in db.")
|
||||
})?))
|
||||
self.roomuserid_invitecount.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.roomuserid_leftcount
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid leftcount in db."))
|
||||
})
|
||||
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms this user joined.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_joined<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(
|
||||
self.userroomid_joined
|
||||
.scan_prefix(user_id.as_bytes().to_vec())
|
||||
.map(|(key, _)| {
|
||||
fn rooms_joined<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
|
||||
Box::new(self.userroomid_joined.scan_prefix(user_id.as_bytes().to_vec()).map(|(key, _)| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_joined is invalid unicode.")
|
||||
})?,
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns an iterator over all rooms a user was invited to.
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(
|
||||
self.userroomid_invitestate
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, state)| {
|
||||
Box::new(self.userroomid_invitestate.scan_prefix(prefix).map(|(key, state)| {
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
||||
})?;
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
|
||||
|
||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
||||
Error::bad_database("Invalid state in userroomid_invitestate.")
|
||||
})?;
|
||||
let state = serde_json::from_slice(&state)
|
||||
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
|
||||
|
||||
Ok((room_id, state))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn invite_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_invitestate
|
||||
|
@ -534,13 +386,9 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn left_state(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_leftstate
|
||||
|
@ -558,39 +406,26 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
Box::new(
|
||||
self.userroomid_leftstate
|
||||
.scan_prefix(prefix)
|
||||
.map(|(key, state)| {
|
||||
Box::new(self.userroomid_leftstate.scan_prefix(prefix).map(|(key, state)| {
|
||||
let room_id = RoomId::parse(
|
||||
utils::string_from_bytes(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
.next()
|
||||
.expect("rsplit always returns an element"),
|
||||
utils::string_from_bytes(key.rsplit(|&b| b == 0xFF).next().expect("rsplit always returns an element"))
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Room ID in userroomid_invited is invalid.")
|
||||
})?;
|
||||
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
|
||||
|
||||
let state = serde_json::from_slice(&state).map_err(|_| {
|
||||
Error::bad_database("Invalid state in userroomid_leftstate.")
|
||||
})?;
|
||||
let state = serde_json::from_slice(&state)
|
||||
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
|
||||
|
||||
Ok((room_id, state))
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
|
||||
|
@ -599,7 +434,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
|
||||
|
@ -608,7 +443,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
|
||||
|
@ -617,7 +452,7 @@ impl service::rooms::state_cache::Data for KeyValueDatabase {
|
|||
#[tracing::instrument(skip(self))]
|
||||
fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
|
||||
@@ -12,9 +12,12 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
 .shortstatehash_statediff
 .get(&shortstatehash.to_be_bytes())?
 .ok_or_else(|| Error::bad_database("State hash does not exist"))?;
-let parent =
-utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
-let parent = if parent != 0 { Some(parent) } else { None };
+let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
+let parent = if parent != 0 {
+Some(parent)
+} else {
+None
+};

 let mut add_mode = true;
 let mut added = HashSet::new();
@@ -55,7 +58,6 @@ impl service::rooms::state_compressor::Data for KeyValueDatabase {
 }
 }

-self.shortstatehash_statediff
-.insert(&shortstatehash.to_be_bytes(), &value)
+self.shortstatehash_statediff.insert(&shortstatehash.to_be_bytes(), &value)
 }
 }

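Judging from the parent pointer (0 meaning none) and the add/remove bookkeeping above, each state diff records what changed relative to a parent snapshot. A hedged sketch of how such a chain could be resolved into a full set; the types and layout here model the idea only and are not the real service-layer structures:

use std::collections::{HashMap, HashSet};

struct Diff {
    parent: Option<u64>,
    added: HashSet<u64>,
    removed: HashSet<u64>,
}

// Sketch: walk the parent chain, then apply removals and additions from the
// oldest ancestor down to the requested snapshot.
fn resolve(diffs: &HashMap<u64, Diff>, mut id: u64) -> HashSet<u64> {
    let mut chain = Vec::new();
    loop {
        let diff = &diffs[&id];
        chain.push(diff);
        match diff.parent {
            Some(parent) => id = parent,
            None => break,
        }
    }
    let mut state = HashSet::new();
    for diff in chain.iter().rev() {
        for r in &diff.removed {
            state.remove(r);
        }
        state.extend(diff.added.iter().copied());
    }
    state
}

fn main() {
    let mut diffs = HashMap::new();
    diffs.insert(1, Diff { parent: None, added: [10, 11].into(), removed: HashSet::new() });
    diffs.insert(2, Diff { parent: Some(1), added: [12].into(), removed: [10].into() });
    assert_eq!(resolve(&diffs, 2), HashSet::from([11, 12]));
}
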
@ -8,51 +8,34 @@ type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEve
|
|||
|
||||
impl service::rooms::threads::Data for KeyValueDatabase {
|
||||
fn threads_until<'a>(
|
||||
&'a self,
|
||||
user_id: &'a UserId,
|
||||
room_id: &'a RoomId,
|
||||
until: u64,
|
||||
_include: &'a IncludeThreads,
|
||||
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
|
||||
) -> PduEventIterResult<'a> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists")
|
||||
.to_be_bytes()
|
||||
.to_vec();
|
||||
let prefix = services().rooms.short.get_shortroomid(room_id)?.expect("room exists").to_be_bytes().to_vec();
|
||||
|
||||
let mut current = prefix.clone();
|
||||
current.extend_from_slice(&(until - 1).to_be_bytes());
|
||||
|
||||
Ok(Box::new(
|
||||
self.threadid_userids
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pduid, _users)| {
|
||||
self.threadid_userids.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pduid, _users)| {
|
||||
let count = utils::u64_from_bytes(&pduid[(mem::size_of::<u64>())..])
|
||||
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
|
||||
let mut pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu_from_id(&pduid)?
|
||||
.ok_or_else(|| {
|
||||
Error::bad_database("Invalid pduid reference in threadid_userids")
|
||||
})?;
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
|
||||
if pdu.sender != user_id {
|
||||
pdu.remove_transaction_id()?;
|
||||
}
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
|
||||
let users = participants
|
||||
.iter()
|
||||
.map(|user| user.as_bytes())
|
||||
.collect::<Vec<_>>()
|
||||
.join(&[0xff][..]);
|
||||
let users = participants.iter().map(|user| user.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..]);
|
||||
|
||||
self.threadid_userids.insert(root_id, &users)?;
|
||||
|
||||
|
@ -63,11 +46,12 @@ impl service::rooms::threads::Data for KeyValueDatabase {
|
|||
if let Some(users) = self.threadid_userids.get(root_id)? {
|
||||
Ok(Some(
|
||||
users
|
||||
.split(|b| *b == 0xff)
|
||||
.split(|b| *b == 0xFF)
|
||||
.map(|bytes| {
|
||||
UserId::parse(utils::string_from_bytes(bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid UserId bytes in threadid_userids.")
|
||||
})?)
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
|
||||
})
|
||||
.filter_map(std::result::Result::ok)
|
||||
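The thread-participants code above serialises the user list by joining IDs with a 0xFF byte and splits on the same byte to read them back. A round-trip sketch of that encoding, with plain strings standing in for the typed user IDs:

// Sketch: encode a list of IDs as a single 0xFF-separated byte value and
// decode it again, skipping anything that is not valid UTF-8.
fn encode(ids: &[&str]) -> Vec<u8> {
    ids.iter().map(|id| id.as_bytes()).collect::<Vec<_>>().join(&[0xFF][..])
}

fn decode(bytes: &[u8]) -> Vec<String> {
    bytes
        .split(|b| *b == 0xFF)
        .filter_map(|chunk| String::from_utf8(chunk.to_vec()).ok())
        .collect()
}

fn main() {
    let users = ["@alice:example.org", "@bob:example.org"];
    let encoded = encode(&users);
    assert_eq!(decode(&encoded), users);
}
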
@ -1,48 +1,34 @@
|
|||
use std::{collections::hash_map, mem::size_of, sync::Arc};
|
||||
|
||||
use ruma::{
|
||||
api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedUserId, RoomId, UserId};
|
||||
use service::rooms::timeline::PduCount;
|
||||
use tracing::error;
|
||||
|
||||
use crate::{database::KeyValueDatabase, service, services, utils, Error, PduEvent, Result};
|
||||
|
||||
use service::rooms::timeline::PduCount;
|
||||
|
||||
impl service::rooms::timeline::Data for KeyValueDatabase {
|
||||
fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
|
||||
match self
|
||||
.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.entry(room_id.to_owned())
|
||||
{
|
||||
match self.lasttimelinecount_cache.lock().unwrap().entry(room_id.to_owned()) {
|
||||
hash_map::Entry::Vacant(v) => {
|
||||
if let Some(last_count) = self
|
||||
.pdus_until(sender_user, room_id, PduCount::max())?
|
||||
.find_map(|r| {
|
||||
if let Some(last_count) = self.pdus_until(sender_user, room_id, PduCount::max())?.find_map(|r| {
|
||||
// Filter out buggy events
|
||||
if r.is_err() {
|
||||
error!("Bad pdu in pdus_since: {:?}", r);
|
||||
}
|
||||
r.ok()
|
||||
})
|
||||
{
|
||||
}) {
|
||||
Ok(*v.insert(last_count.0))
|
||||
} else {
|
||||
Ok(PduCount::Normal(0))
|
||||
}
|
||||
}
|
||||
},
|
||||
hash_map::Entry::Occupied(o) => Ok(*o.get()),
|
||||
}
|
||||
}
|
||||
|
||||
/// Returns the `count` of this pdu's id.
|
||||
fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu_id| pdu_count(&pdu_id))
|
||||
.transpose()
|
||||
self.eventid_pduid.get(event_id.as_bytes())?.map(|pdu_id| pdu_count(&pdu_id)).transpose()
|
||||
}
|
||||
|
||||
/// Returns the json of a pdu.
|
||||
|
@ -51,10 +37,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
|
@ -66,35 +49,25 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
/// Returns the pdu's id.
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> {
|
||||
self.eventid_pduid.get(event_id.as_bytes())
|
||||
}
|
||||
fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<Vec<u8>>> { self.eventid_pduid.get(event_id.as_bytes()) }
|
||||
|
||||
/// Returns the pdu.
|
||||
fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
|
||||
self.eventid_pduid
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pduid| {
|
||||
self.pduid_pdu
|
||||
.get(&pduid)?
|
||||
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
self.pduid_pdu.get(&pduid)?.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
|
||||
})
|
||||
.transpose()?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
}
|
||||
|
||||
|
@ -112,20 +85,14 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
|| {
|
||||
self.eventid_outlierpdu
|
||||
.get(event_id.as_bytes())?
|
||||
.map(|pdu| {
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))
|
||||
})
|
||||
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
|
||||
.transpose()
|
||||
},
|
||||
|x| Ok(Some(x)),
|
||||
)?
|
||||
.map(Arc::new)
|
||||
{
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
self.pdu_cache.lock().unwrap().insert(event_id.to_owned(), Arc::clone(&pdu));
|
||||
Ok(Some(pdu))
|
||||
} else {
|
||||
Ok(None)
|
||||
|
@ -138,8 +105,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
@ -148,28 +114,18 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
|
||||
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&pdu)
|
||||
.map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn append_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu: &PduEvent,
|
||||
json: &CanonicalJsonObject,
|
||||
count: u64,
|
||||
) -> Result<()> {
|
||||
fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
|
||||
self.lasttimelinecount_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
self.lasttimelinecount_cache.lock().unwrap().insert(pdu.room_id.clone(), PduCount::Normal(count));
|
||||
|
||||
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
|
||||
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
|
||||
|
@ -177,12 +133,7 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn prepend_backfill_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
event_id: &EventId,
|
||||
json: &CanonicalJsonObject,
|
||||
) -> Result<()> {
|
||||
fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) -> Result<()> {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
|
||||
|
@ -195,49 +146,34 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
/// Removes a pdu and creates a new one with the same id.
|
||||
fn replace_pdu(
|
||||
&self,
|
||||
pdu_id: &[u8],
|
||||
pdu_json: &CanonicalJsonObject,
|
||||
pdu: &PduEvent,
|
||||
) -> Result<()> {
|
||||
fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
|
||||
if self.pduid_pdu.get(pdu_id)?.is_some() {
|
||||
self.pduid_pdu.insert(
|
||||
pdu_id,
|
||||
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
|
||||
)?;
|
||||
} else {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::NotFound,
|
||||
"PDU does not exist.",
|
||||
));
|
||||
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
|
||||
}
|
||||
|
||||
self.pdu_cache
|
||||
.lock()
|
||||
.unwrap()
|
||||
.remove(&(*pdu.event_id).to_owned());
|
||||
self.pdu_cache.lock().unwrap().remove(&(*pdu.event_id).to_owned());
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Returns an iterator over all events and their tokens in a room that happened before the
/// event with id `until` in reverse-chronological order.
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
fn pdus_until<'a>(
&'a self,
user_id: &UserId,
room_id: &RoomId,
until: PduCount,
&'a self, user_id: &UserId, room_id: &RoomId, until: PduCount,
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, until, 1, true)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, true)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
self.pduid_pdu.iter_from(¤t, true).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
|
@ -246,25 +182,21 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn pdus_after<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
room_id: &RoomId,
|
||||
from: PduCount,
|
||||
&'a self, user_id: &UserId, room_id: &RoomId, from: PduCount,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<(PduCount, PduEvent)>> + 'a>> {
|
||||
let (prefix, current) = count_to_id(room_id, from, 1, false)?;
|
||||
|
||||
let user_id = user_id.to_owned();
|
||||
|
||||
Ok(Box::new(
|
||||
self.pduid_pdu
|
||||
.iter_from(¤t, false)
|
||||
.take_while(move |(k, _)| k.starts_with(&prefix))
|
||||
.map(move |(pdu_id, v)| {
|
||||
self.pduid_pdu.iter_from(¤t, false).take_while(move |(k, _)| k.starts_with(&prefix)).map(
|
||||
move |(pdu_id, v)| {
|
||||
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
|
||||
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
|
||||
if pdu.sender != user_id {
|
||||
|
@ -273,35 +205,31 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
pdu.add_age()?;
|
||||
let count = pdu_count(&pdu_id)?;
|
||||
Ok((count, pdu))
|
||||
}),
|
||||
},
|
||||
),
|
||||
))
|
||||
}
|
||||
|
||||
fn increment_notification_counts(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
notifies: Vec<OwnedUserId>,
|
||||
highlights: Vec<OwnedUserId>,
|
||||
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
|
||||
) -> Result<()> {
|
||||
let mut notifies_batch = Vec::new();
|
||||
let mut highlights_batch = Vec::new();
|
||||
for user in notifies {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
notifies_batch.push(userroom_id);
|
||||
}
|
||||
for user in highlights {
|
||||
let mut userroom_id = user.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
highlights_batch.push(userroom_id);
|
||||
}
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount
|
||||
.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
self.userroomid_notificationcount.increment_batch(&mut notifies_batch.into_iter())?;
|
||||
self.userroomid_highlightcount.increment_batch(&mut highlights_batch.into_iter())?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
@ -310,9 +238,8 @@ impl service::rooms::timeline::Data for KeyValueDatabase {
|
|||
fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
||||
let last_u64 = utils::u64_from_bytes(&pdu_id[pdu_id.len() - size_of::<u64>()..])
|
||||
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
|
||||
let second_last_u64 = utils::u64_from_bytes(
|
||||
&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()],
|
||||
);
|
||||
let second_last_u64 =
|
||||
utils::u64_from_bytes(&pdu_id[pdu_id.len() - 2 * size_of::<u64>()..pdu_id.len() - size_of::<u64>()]);
|
||||
|
||||
if matches!(second_last_u64, Ok(0)) {
|
||||
Ok(PduCount::Backfilled(u64::MAX - last_u64))
|
||||
|
@ -321,12 +248,7 @@ fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
|
|||
}
|
||||
}
|
||||
|
||||
fn count_to_id(
|
||||
room_id: &RoomId,
|
||||
count: PduCount,
|
||||
offset: u64,
|
||||
subtract: bool,
|
||||
) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
fn count_to_id(room_id: &RoomId, count: PduCount, offset: u64, subtract: bool) -> Result<(Vec<u8>, Vec<u8>)> {
|
||||
let prefix = services()
|
||||
.rooms
|
||||
.short
|
||||
|
@ -343,7 +265,7 @@ fn count_to_id(
|
|||
} else {
|
||||
x + offset
|
||||
}
|
||||
}
|
||||
},
|
||||
PduCount::Backfilled(x) => {
|
||||
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
|
||||
let num = u64::MAX - x;
|
||||
|
@ -356,7 +278,7 @@ fn count_to_id(
|
|||
} else {
|
||||
num + offset
|
||||
}
|
||||
}
|
||||
},
|
||||
};
|
||||
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
|
||||
|
||||
|
|
|
@ -5,95 +5,73 @@ use crate::{database::KeyValueDatabase, service, services, utils, Error, Result}
|
|||
impl service::rooms::user::Data for KeyValueDatabase {
|
||||
fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
let mut roomuser_id = room_id.as_bytes().to_vec();
|
||||
roomuser_id.push(0xff);
|
||||
roomuser_id.push(0xFF);
|
||||
roomuser_id.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount
|
||||
.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_notificationcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
self.userroomid_highlightcount.insert(&userroom_id, &0_u64.to_be_bytes())?;
|
||||
|
||||
self.roomuserid_lastnotificationread.insert(
|
||||
&roomuser_id,
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.roomuserid_lastnotificationread.insert(&roomuser_id, &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_notificationcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut userroom_id = user_id.as_bytes().to_vec();
|
||||
userroom_id.push(0xff);
|
||||
userroom_id.push(0xFF);
|
||||
userroom_id.extend_from_slice(room_id.as_bytes());
|
||||
|
||||
self.userroomid_highlightcount
|
||||
.get(&userroom_id)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id.as_bytes());
|
||||
|
||||
Ok(self
|
||||
.roomuserid_lastnotificationread
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid.")
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
|
||||
})
|
||||
.transpose()?
|
||||
.unwrap_or(0))
|
||||
}
|
||||
|
||||
fn associate_token_shortstatehash(
|
||||
&self,
|
||||
room_id: &RoomId,
|
||||
token: u64,
|
||||
shortstatehash: u64,
|
||||
) -> Result<()> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
||||
self.roomsynctoken_shortstatehash
|
||||
.insert(&key, &shortstatehash.to_be_bytes())
|
||||
self.roomsynctoken_shortstatehash.insert(&key, &shortstatehash.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
|
||||
let shortroomid = services()
|
||||
.rooms
|
||||
.short
|
||||
.get_shortroomid(room_id)?
|
||||
.expect("room exists");
|
||||
let shortroomid = services().rooms.short.get_shortroomid(room_id)?.expect("room exists");
|
||||
|
||||
let mut key = shortroomid.to_be_bytes().to_vec();
|
||||
key.extend_from_slice(&token.to_be_bytes());
|
||||
|
@ -101,20 +79,18 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
|||
self.roomsynctoken_shortstatehash
|
||||
.get(&key)?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash")
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
|
||||
})
|
||||
.transpose()
|
||||
}
|
||||
|
||||
fn get_shared_rooms<'a>(
|
||||
&'a self,
|
||||
users: Vec<OwnedUserId>,
|
||||
&'a self, users: Vec<OwnedUserId>,
|
||||
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
|
||||
let iterators = users.into_iter().map(move |user_id| {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
self.userroomid_joined
|
||||
.scan_prefix(prefix)
|
||||
|
@ -122,10 +98,9 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
|||
let roomid_index = key
|
||||
.iter()
|
||||
.enumerate()
|
||||
.find(|(_, &b)| b == 0xff)
|
||||
.find(|(_, &b)| b == 0xFF)
|
||||
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
|
||||
.0
|
||||
+ 1; // +1 because the room id starts AFTER the separator
|
||||
.0 + 1; // +1 because the room id starts AFTER the separator
|
||||
|
||||
let room_id = key[roomid_index..].to_vec();
|
||||
|
||||
|
@ -134,14 +109,14 @@ impl service::rooms::user::Data for KeyValueDatabase {
|
|||
.filter_map(std::result::Result::ok)
|
||||
});
|
||||
|
||||
// We use the default compare function because keys are sorted correctly (not reversed)
|
||||
// We use the default compare function because keys are sorted correctly (not
|
||||
// reversed)
|
||||
Ok(Box::new(
|
||||
utils::common_elements(iterators, Ord::cmp)
|
||||
.expect("users is not empty")
|
||||
.map(|bytes| {
|
||||
RoomId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Invalid RoomId bytes in userroomid_joined")
|
||||
})?)
|
||||
utils::common_elements(iterators, Ord::cmp).expect("users is not empty").map(|bytes| {
|
||||
RoomId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
|
||||
}),
|
||||
))
|
||||
|
|
|
@ -21,8 +21,7 @@ impl service::sending::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn active_requests_for<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(Vec<u8>, SendingEventType)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
Box::new(
|
||||
|
@ -32,9 +31,7 @@ impl service::sending::Data for KeyValueDatabase {
|
|||
)
|
||||
}
|
||||
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> {
|
||||
self.servercurrentevent_data.remove(&key)
|
||||
}
|
||||
fn delete_active_request(&self, key: Vec<u8>) -> Result<()> { self.servercurrentevent_data.remove(&key) }
|
||||
|
||||
fn delete_all_active_requests_for(&self, outgoing_kind: &OutgoingKind) -> Result<()> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
|
@ -58,10 +55,7 @@ impl service::sending::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn queue_requests(
|
||||
&self,
|
||||
requests: &[(&OutgoingKind, SendingEventType)],
|
||||
) -> Result<Vec<Vec<u8>>> {
|
||||
fn queue_requests(&self, requests: &[(&OutgoingKind, SendingEventType)]) -> Result<Vec<Vec<u8>>> {
|
||||
let mut batch = Vec::new();
|
||||
let mut keys = Vec::new();
|
||||
for (outgoing_kind, event) in requests {
|
||||
|
@ -79,14 +73,12 @@ impl service::sending::Data for KeyValueDatabase {
|
|||
batch.push((key.clone(), value.to_owned()));
|
||||
keys.push(key);
|
||||
}
|
||||
self.servernameevent_data
|
||||
.insert_batch(&mut batch.into_iter())?;
|
||||
self.servernameevent_data.insert_batch(&mut batch.into_iter())?;
|
||||
Ok(keys)
|
||||
}
|
||||
|
||||
fn queued_requests<'a>(
|
||||
&'a self,
|
||||
outgoing_kind: &OutgoingKind,
|
||||
&'a self, outgoing_kind: &OutgoingKind,
|
||||
) -> Box<dyn Iterator<Item = Result<(SendingEventType, Vec<u8>)>> + 'a> {
|
||||
let prefix = outgoing_kind.get_prefix();
|
||||
return Box::new(
|
||||
|
@ -111,37 +103,27 @@ impl service::sending::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn set_latest_educount(&self, server_name: &ServerName, last_count: u64) -> Result<()> {
|
||||
self.servername_educount
|
||||
.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
self.servername_educount.insert(server_name.as_bytes(), &last_count.to_be_bytes())
|
||||
}
|
||||
|
||||
fn get_latest_educount(&self, server_name: &ServerName) -> Result<u64> {
|
||||
self.servername_educount
|
||||
.get(server_name.as_bytes())?
|
||||
.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
self.servername_educount.get(server_name.as_bytes())?.map_or(Ok(0), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid u64 in servername_educount."))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[tracing::instrument(skip(key))]
|
||||
fn parse_servercurrentevent(
|
||||
key: &[u8],
|
||||
value: Vec<u8>,
|
||||
) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
fn parse_servercurrentevent(key: &[u8], value: Vec<u8>) -> Result<(OutgoingKind, SendingEventType)> {
|
||||
// Appservices start with a plus
|
||||
Ok::<_, Error>(if key.starts_with(b"+") {
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xff);
|
||||
let mut parts = key[1..].splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Appservice(server),
|
||||
|
@ -152,23 +134,19 @@ fn parse_servercurrentevent(
|
|||
},
|
||||
)
|
||||
} else if key.starts_with(b"$") {
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xff);
|
||||
let mut parts = key[1..].splitn(3, |&b| b == 0xFF);
|
||||
|
||||
let user = parts.next().expect("splitn always returns one element");
|
||||
let user_string = utils::string_from_bytes(user)
|
||||
.map_err(|_| Error::bad_database("Invalid user string in servercurrentevent"))?;
|
||||
let user_id = UserId::parse(user_string)
|
||||
.map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
let user_id =
|
||||
UserId::parse(user_string).map_err(|_| Error::bad_database("Invalid user id in servercurrentevent"))?;
|
||||
|
||||
let pushkey = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let pushkey_string = utils::string_from_bytes(pushkey)
|
||||
.map_err(|_| Error::bad_database("Invalid pushkey in servercurrentevent"))?;
|
||||
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Push(user_id, pushkey_string),
|
||||
|
@ -180,21 +158,19 @@ fn parse_servercurrentevent(
|
|||
},
|
||||
)
|
||||
} else {
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||
|
||||
let server = parts.next().expect("splitn always returns one element");
|
||||
let event = parts
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
let event = parts.next().ok_or_else(|| Error::bad_database("Invalid bytes in servercurrentpdus."))?;
|
||||
|
||||
let server = utils::string_from_bytes(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server bytes in server_currenttransaction")
|
||||
})?;
|
||||
let server = utils::string_from_bytes(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server bytes in server_currenttransaction"))?;
|
||||
|
||||
(
|
||||
OutgoingKind::Normal(ServerName::parse(server).map_err(|_| {
|
||||
Error::bad_database("Invalid server string in server_currenttransaction")
|
||||
})?),
|
||||
OutgoingKind::Normal(
|
||||
ServerName::parse(server)
|
||||
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
|
||||
),
|
||||
if value.is_empty() {
|
||||
SendingEventType::Pdu(event.to_vec())
|
||||
} else {
|
||||
|
|
|
@ -4,16 +4,12 @@ use crate::{database::KeyValueDatabase, service, Result};
|
|||
|
||||
impl service::transaction_ids::Data for KeyValueDatabase {
|
||||
fn add_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
data: &[u8],
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId, data: &[u8],
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
self.userdevicetxnid_response.insert(&key, data)?;
|
||||
|
@ -22,15 +18,12 @@ impl service::transaction_ids::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn existing_txnid(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: Option<&DeviceId>,
|
||||
txn_id: &TransactionId,
|
||||
&self, user_id: &UserId, device_id: Option<&DeviceId>, txn_id: &TransactionId,
|
||||
) -> Result<Option<Vec<u8>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.map(DeviceId::as_bytes).unwrap_or_default());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(txn_id.as_bytes());
|
||||
|
||||
// If there's no entry, this is a new transaction
|
||||
|
|
|
@ -7,16 +7,9 @@ use crate::{database::KeyValueDatabase, service, Error, Result};
|
|||
|
||||
impl service::uiaa::Data for KeyValueDatabase {
|
||||
fn set_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
request: &CanonicalJsonValue,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, request: &CanonicalJsonValue,
|
||||
) -> Result<()> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.write()
|
||||
.unwrap()
|
||||
.insert(
|
||||
self.userdevicesessionid_uiaarequest.write().unwrap().insert(
|
||||
(user_id.to_owned(), device_id.to_owned(), session.to_owned()),
|
||||
request.to_owned(),
|
||||
);
|
||||
|
@ -24,12 +17,7 @@ impl service::uiaa::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_request(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Option<CanonicalJsonValue> {
|
||||
fn get_uiaa_request(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Option<CanonicalJsonValue> {
|
||||
self.userdevicesessionid_uiaarequest
|
||||
.read()
|
||||
.unwrap()
|
||||
|
@ -38,16 +26,12 @@ impl service::uiaa::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn update_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
uiaainfo: Option<&UiaaInfo>,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, session: &str, uiaainfo: Option<&UiaaInfo>,
|
||||
) -> Result<()> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
if let Some(uiaainfo) = uiaainfo {
|
||||
|
@ -56,33 +40,24 @@ impl service::uiaa::Data for KeyValueDatabase {
|
|||
&serde_json::to_vec(&uiaainfo).expect("UiaaInfo::to_vec always works"),
|
||||
)?;
|
||||
} else {
|
||||
self.userdevicesessionid_uiaainfo
|
||||
.remove(&userdevicesessionid)?;
|
||||
self.userdevicesessionid_uiaainfo.remove(&userdevicesessionid)?;
|
||||
}
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_uiaa_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
session: &str,
|
||||
) -> Result<UiaaInfo> {
|
||||
fn get_uiaa_session(&self, user_id: &UserId, device_id: &DeviceId, session: &str) -> Result<UiaaInfo> {
|
||||
let mut userdevicesessionid = user_id.as_bytes().to_vec();
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(device_id.as_bytes());
|
||||
userdevicesessionid.push(0xff);
|
||||
userdevicesessionid.push(0xFF);
|
||||
userdevicesessionid.extend_from_slice(session.as_bytes());
|
||||
|
||||
serde_json::from_slice(
|
||||
&self
|
||||
.userdevicesessionid_uiaainfo
|
||||
.get(&userdevicesessionid)?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::Forbidden,
|
||||
"UIAA session does not exist.",
|
||||
))?,
|
||||
.ok_or(Error::BadRequest(ErrorKind::Forbidden, "UIAA session does not exist."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("UiaaInfo in userdeviceid_uiaainfo is invalid."))
|
||||
}
|
||||
|
|
|
@ -5,8 +5,8 @@ use ruma::{
|
|||
encryption::{CrossSigningKey, DeviceKeys, OneTimeKey},
events::{AnyToDeviceEvent, StateEventType},
serde::Raw,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId,
OwnedDeviceKeyId, OwnedMxcUri, OwnedUserId, UInt, UserId,
DeviceId, DeviceKeyAlgorithm, DeviceKeyId, MilliSecondsSinceUnixEpoch, OwnedDeviceId, OwnedDeviceKeyId,
OwnedMxcUri, OwnedUserId, UInt, UserId,
};
use tracing::warn;
|
||||
|
||||
|
@ -18,50 +18,37 @@ use crate::{
|
|||
|
||||
impl service::users::Data for KeyValueDatabase {
|
||||
/// Check if a user has an account on this homeserver.
|
||||
fn exists(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self.userid_password.get(user_id.as_bytes())?.is_some())
|
||||
}
|
||||
fn exists(&self, user_id: &UserId) -> Result<bool> { Ok(self.userid_password.get(user_id.as_bytes())?.is_some()) }
|
||||
|
||||
/// Check if account is deactivated
|
||||
fn is_deactivated(&self, user_id: &UserId) -> Result<bool> {
|
||||
Ok(self
|
||||
.userid_password
|
||||
.get(user_id.as_bytes())?
|
||||
.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User does not exist.",
|
||||
))?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."))?
|
||||
.is_empty())
|
||||
}
|
||||
|
||||
/// Returns the number of users registered on this server.
|
||||
fn count(&self) -> Result<usize> {
|
||||
Ok(self.userid_password.iter().count())
|
||||
}
|
||||
fn count(&self) -> Result<usize> { Ok(self.userid_password.iter().count()) }
|
||||
|
||||
/// Find out which user an access token belongs to.
|
||||
fn find_from_token(&self, token: &str) -> Result<Option<(OwnedUserId, String)>> {
|
||||
self.token_userdeviceid
|
||||
.get(token.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
let mut parts = bytes.split(|&b| b == 0xff);
|
||||
let user_bytes = parts.next().ok_or_else(|| {
|
||||
Error::bad_database("User ID in token_userdeviceid is invalid.")
|
||||
})?;
|
||||
let device_bytes = parts.next().ok_or_else(|| {
|
||||
Error::bad_database("Device ID in token_userdeviceid is invalid.")
|
||||
})?;
|
||||
self.token_userdeviceid.get(token.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
let mut parts = bytes.split(|&b| b == 0xFF);
|
||||
let user_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("User ID in token_userdeviceid is invalid."))?;
|
||||
let device_bytes =
|
||||
parts.next().ok_or_else(|| Error::bad_database("Device ID in token_userdeviceid is invalid."))?;
|
||||
|
||||
Ok(Some((
|
||||
UserId::parse(utils::string_from_bytes(user_bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in token_userdeviceid is invalid unicode.")
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in token_userdeviceid is invalid.")
|
||||
})?,
|
||||
utils::string_from_bytes(device_bytes).map_err(|_| {
|
||||
Error::bad_database("Device ID in token_userdeviceid is invalid.")
|
||||
})?,
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(user_bytes)
|
||||
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in token_userdeviceid is invalid."))?,
|
||||
utils::string_from_bytes(device_bytes)
|
||||
.map_err(|_| Error::bad_database("Device ID in token_userdeviceid is invalid."))?,
|
||||
)))
|
||||
})
|
||||
}
|
||||
|
@ -69,16 +56,18 @@ impl service::users::Data for KeyValueDatabase {
|
|||
/// Returns an iterator over all users on this homeserver.
|
||||
fn iter<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
Box::new(self.userid_password.iter().map(|(bytes, _)| {
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in userid_password is invalid unicode.")
|
||||
})?)
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("User ID in userid_password is invalid unicode."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("User ID in userid_password is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
/// Returns a list of local users as list of usernames.
|
||||
///
|
||||
/// A user account is considered `local` if the length of it's password is greater then zero.
|
||||
/// A user account is considered `local` if the length of it's password is
|
||||
/// greater then zero.
|
||||
fn list_local_users(&self) -> Result<Vec<String>> {
|
||||
let users: Vec<String> = self
|
||||
.userid_password
|
||||
|
@ -90,9 +79,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
/// Returns the password hash for the given user.
|
||||
fn password_hash(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_password
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
self.userid_password.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Password hash in db is not valid string.")
|
||||
})?))
|
||||
|
@ -103,8 +90,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
fn set_password(&self, user_id: &UserId, password: Option<&str>) -> Result<()> {
|
||||
if let Some(password) = password {
|
||||
if let Ok(hash) = utils::calculate_password_hash(password) {
|
||||
self.userid_password
|
||||
.insert(user_id.as_bytes(), hash.as_bytes())?;
|
||||
self.userid_password.insert(user_id.as_bytes(), hash.as_bytes())?;
|
||||
Ok(())
|
||||
} else {
|
||||
Err(Error::BadRequest(
|
||||
|
@ -120,20 +106,18 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
/// Returns the displayname of a user on this homeserver.
|
||||
fn displayname(&self, user_id: &UserId) -> Result<Option<String>> {
|
||||
self.userid_displayname
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Displayname in db is invalid.")
|
||||
})?))
|
||||
self.userid_displayname.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| Error::bad_database("Displayname in db is invalid."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
/// Sets a new displayname or removes it if displayname is None. You still need to notify all rooms of this change.
/// Sets a new displayname or removes it if displayname is None. You still
/// need to notify all rooms of this change.
|
||||
fn set_displayname(&self, user_id: &UserId, displayname: Option<String>) -> Result<()> {
|
||||
if let Some(displayname) = displayname {
|
||||
self.userid_displayname
|
||||
.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
||||
self.userid_displayname.insert(user_id.as_bytes(), displayname.as_bytes())?;
|
||||
} else {
|
||||
self.userid_displayname.remove(user_id.as_bytes())?;
|
||||
}
|
||||
|
@ -159,8 +143,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
/// Sets a new avatar_url or removes it if avatar_url is None.
|
||||
fn set_avatar_url(&self, user_id: &UserId, avatar_url: Option<OwnedMxcUri>) -> Result<()> {
|
||||
if let Some(avatar_url) = avatar_url {
|
||||
self.userid_avatarurl
|
||||
.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
||||
self.userid_avatarurl.insert(user_id.as_bytes(), avatar_url.to_string().as_bytes())?;
|
||||
} else {
|
||||
self.userid_avatarurl.remove(user_id.as_bytes())?;
|
||||
}
|
||||
|
@ -184,8 +167,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
/// Sets a new blurhash or removes it if blurhash is None.
|
||||
fn set_blurhash(&self, user_id: &UserId, blurhash: Option<String>) -> Result<()> {
|
||||
if let Some(blurhash) = blurhash {
|
||||
self.userid_blurhash
|
||||
.insert(user_id.as_bytes(), blurhash.as_bytes())?;
|
||||
self.userid_blurhash.insert(user_id.as_bytes(), blurhash.as_bytes())?;
|
||||
} else {
|
||||
self.userid_blurhash.remove(user_id.as_bytes())?;
|
||||
}
|
||||
|
@ -195,30 +177,20 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
/// Adds a new device to a user.
|
||||
fn create_device(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
token: &str,
|
||||
initial_device_display_name: Option<String>,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, token: &str, initial_device_display_name: Option<String>,
|
||||
) -> Result<()> {
|
||||
// This method should never be called for nonexistent users. We shouldn't assert though...
|
||||
// This method should never be called for nonexistent users. We shouldn't assert
|
||||
// though...
|
||||
if !self.exists(user_id)? {
|
||||
warn!(
|
||||
"Called create_device for non-existent user {} in database",
|
||||
user_id
|
||||
);
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User does not exist.",
|
||||
));
|
||||
warn!("Called create_device for non-existent user {} in database", user_id);
|
||||
return Err(Error::BadRequest(ErrorKind::InvalidParam, "User does not exist."));
|
||||
}
|
||||
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.userid_devicelistversion
|
||||
.increment(user_id.as_bytes())?;
|
||||
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.insert(
|
||||
&userdeviceid,
|
||||
|
@ -239,7 +211,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
/// Removes a device from a user.
|
||||
fn remove_device(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
// Remove tokens
|
||||
|
@ -250,7 +222,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
// Remove todevice events
|
||||
let mut prefix = userdeviceid.clone();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (key, _) in self.todeviceid_events.scan_prefix(prefix) {
|
||||
self.todeviceid_events.remove(&key)?;
|
||||
|
@ -258,8 +230,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
// TODO: Remove onetimekeys
|
||||
|
||||
self.userid_devicelistversion
|
||||
.increment(user_id.as_bytes())?;
|
||||
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.remove(&userdeviceid)?;
|
||||
|
||||
|
@ -267,39 +238,34 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
/// Returns an iterator over all device ids of this user.
|
||||
fn all_device_ids<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
|
||||
fn all_device_ids<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedDeviceId>> + 'a> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
// All devices have metadata
|
||||
Box::new(
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(prefix)
|
||||
.map(|(bytes, _)| {
|
||||
Box::new(self.userdeviceid_metadata.scan_prefix(prefix).map(|(bytes, _)| {
|
||||
Ok(utils::string_from_bytes(
|
||||
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("UserDevice ID in db is invalid.")
|
||||
})?,
|
||||
bytes
|
||||
.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("UserDevice ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("Device ID in userdeviceid_metadata is invalid.")
|
||||
})?
|
||||
.map_err(|_| Error::bad_database("Device ID in userdeviceid_metadata is invalid."))?
|
||||
.into())
|
||||
}),
|
||||
)
|
||||
}))
|
||||
}
|
||||
|
||||
/// Replaces the access token of one device.
|
||||
fn set_token(&self, user_id: &UserId, device_id: &DeviceId, token: &str) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
// should not be None, but we shouldn't assert either lol...
|
||||
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
||||
warn!("Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
||||
warn!(
|
||||
"Called set_token for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database",
|
||||
user_id, device_id
|
||||
);
|
||||
return Err(Error::bad_database(
|
||||
"User does not exist or device ID has no metadata in database.",
|
||||
));
|
||||
|
@ -312,41 +278,39 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
// Assign token to user device combination
|
||||
self.userdeviceid_token
|
||||
.insert(&userdeviceid, token.as_bytes())?;
|
||||
self.token_userdeviceid
|
||||
.insert(token.as_bytes(), &userdeviceid)?;
|
||||
self.userdeviceid_token.insert(&userdeviceid, token.as_bytes())?;
|
||||
self.token_userdeviceid.insert(token.as_bytes(), &userdeviceid)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn add_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
one_time_key_key: &DeviceKeyId,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, one_time_key_key: &DeviceKeyId,
|
||||
one_time_key_value: &Raw<OneTimeKey>,
|
||||
) -> Result<()> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
// All devices have metadata
|
||||
// Only existing devices should be able to call this, but we shouldn't assert either...
|
||||
// Only existing devices should be able to call this, but we shouldn't assert
|
||||
// either...
|
||||
if self.userdeviceid_metadata.get(&key)?.is_none() {
|
||||
warn!("Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
||||
warn!(
|
||||
"Called add_one_time_key for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in \
|
||||
database",
|
||||
user_id, device_id
|
||||
);
|
||||
return Err(Error::bad_database(
|
||||
"User does not exist or device ID has no metadata in database.",
|
||||
));
|
||||
}
|
||||
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
// TODO: Use DeviceKeyId::to_string when it's available (and update everything,
|
||||
// because there are no wrapping quotation marks anymore)
|
||||
key.extend_from_slice(
|
||||
serde_json::to_string(one_time_key_key)
|
||||
.expect("DeviceKeyId::to_string always works")
|
||||
.as_bytes(),
|
||||
serde_json::to_string(one_time_key_key).expect("DeviceKeyId::to_string always works").as_bytes(),
|
||||
);
|
||||
|
||||
self.onetimekeyid_onetimekeys.insert(
|
||||
|
@ -354,10 +318,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
&serde_json::to_vec(&one_time_key_value).expect("OneTimeKey::to_vec always works"),
|
||||
)?;
|
||||
|
||||
self.userid_lastonetimekeyupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
@ -366,31 +327,24 @@ impl service::users::Data for KeyValueDatabase {
|
|||
self.userid_lastonetimekeyupdate
|
||||
.get(user_id.as_bytes())?
|
||||
.map(|bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("Count in roomid_lastroomactiveupdate is invalid.")
|
||||
})
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Count in roomid_lastroomactiveupdate is invalid."))
|
||||
})
|
||||
.unwrap_or(Ok(0))
|
||||
}
|
||||
|
||||
fn take_one_time_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
key_algorithm: &DeviceKeyAlgorithm,
|
||||
&self, user_id: &UserId, device_id: &DeviceId, key_algorithm: &DeviceKeyAlgorithm,
|
||||
) -> Result<Option<(OwnedDeviceKeyId, Raw<OneTimeKey>)>> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.push(b'"'); // Annoying quotation mark
|
||||
prefix.extend_from_slice(key_algorithm.as_ref().as_bytes());
|
||||
prefix.push(b':');
|
||||
|
||||
self.userid_lastonetimekeyupdate.insert(
|
||||
user_id.as_bytes(),
|
||||
&services().globals.next_count()?.to_be_bytes(),
|
||||
)?;
|
||||
self.userid_lastonetimekeyupdate.insert(user_id.as_bytes(), &services().globals.next_count()?.to_be_bytes())?;
|
||||
|
||||
self.onetimekeyid_onetimekeys
|
||||
.scan_prefix(prefix)
|
||||
|
@ -400,7 +354,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
|
||||
Ok((
|
||||
serde_json::from_slice(
|
||||
key.rsplit(|&b| b == 0xff)
|
||||
key.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("OneTimeKeyId in db is invalid."))?,
|
||||
)
|
||||
|
@ -413,45 +367,35 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn count_one_time_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
&self, user_id: &UserId, device_id: &DeviceId,
|
||||
) -> Result<BTreeMap<DeviceKeyAlgorithm, UInt>> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
let mut counts = BTreeMap::new();
|
||||
|
||||
for algorithm in
|
||||
self.onetimekeyid_onetimekeys
|
||||
.scan_prefix(userdeviceid)
|
||||
.map(|(bytes, _)| {
|
||||
for algorithm in self.onetimekeyid_onetimekeys.scan_prefix(userdeviceid).map(|(bytes, _)| {
|
||||
Ok::<_, Error>(
|
||||
serde_json::from_slice::<OwnedDeviceKeyId>(
|
||||
bytes.rsplit(|&b| b == 0xff).next().ok_or_else(|| {
|
||||
Error::bad_database("OneTimeKey ID in db is invalid.")
|
||||
})?,
|
||||
bytes
|
||||
.rsplit(|&b| b == 0xFF)
|
||||
.next()
|
||||
.ok_or_else(|| Error::bad_database("OneTimeKey ID in db is invalid."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("DeviceKeyId in db is invalid."))?
|
||||
.algorithm(),
|
||||
)
|
||||
})
|
||||
{
|
||||
}) {
|
||||
*counts.entry(algorithm?).or_default() += UInt::from(1_u32);
|
||||
}
|
||||
|
||||
Ok(counts)
|
||||
}
|
||||
|
||||
fn add_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device_keys: &Raw<DeviceKeys>,
|
||||
) -> Result<()> {
|
||||
fn add_device_keys(&self, user_id: &UserId, device_id: &DeviceId, device_keys: &Raw<DeviceKeys>) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.keyid_key.insert(
|
||||
|
@ -465,39 +409,30 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn add_cross_signing_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
notify: bool,
|
||||
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>, self_signing_key: &Option<Raw<CrossSigningKey>>,
|
||||
user_signing_key: &Option<Raw<CrossSigningKey>>, notify: bool,
|
||||
) -> Result<()> {
|
||||
// TODO: Check signatures
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let (master_key_key, _) = self.parse_master_key(user_id, master_key)?;
|
||||
|
||||
self.keyid_key
|
||||
.insert(&master_key_key, master_key.json().get().as_bytes())?;
|
||||
self.keyid_key.insert(&master_key_key, master_key.json().get().as_bytes())?;
|
||||
|
||||
self.userid_masterkeyid
|
||||
.insert(user_id.as_bytes(), &master_key_key)?;
|
||||
self.userid_masterkeyid.insert(user_id.as_bytes(), &master_key_key)?;
|
||||
|
||||
// Self-signing key
|
||||
if let Some(self_signing_key) = self_signing_key {
|
||||
let mut self_signing_key_ids = self_signing_key
|
||||
.deserialize()
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key")
|
||||
})?
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid self signing key"))?
|
||||
.keys
|
||||
.into_values();
|
||||
|
||||
let self_signing_key_id = self_signing_key_ids.next().ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Self signing key contained no key.",
|
||||
))?;
|
||||
let self_signing_key_id = self_signing_key_ids
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Self signing key contained no key."))?;
|
||||
|
||||
if self_signing_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -509,29 +444,22 @@ impl service::users::Data for KeyValueDatabase {
|
|||
let mut self_signing_key_key = prefix.clone();
|
||||
self_signing_key_key.extend_from_slice(self_signing_key_id.as_bytes());
|
||||
|
||||
self.keyid_key.insert(
|
||||
&self_signing_key_key,
|
||||
self_signing_key.json().get().as_bytes(),
|
||||
)?;
|
||||
self.keyid_key.insert(&self_signing_key_key, self_signing_key.json().get().as_bytes())?;
|
||||
|
||||
self.userid_selfsigningkeyid
|
||||
.insert(user_id.as_bytes(), &self_signing_key_key)?;
|
||||
self.userid_selfsigningkeyid.insert(user_id.as_bytes(), &self_signing_key_key)?;
|
||||
}
|
||||
|
||||
// User-signing key
|
||||
if let Some(user_signing_key) = user_signing_key {
|
||||
let mut user_signing_key_ids = user_signing_key
|
||||
.deserialize()
|
||||
.map_err(|_| {
|
||||
Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key")
|
||||
})?
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid user signing key"))?
|
||||
.keys
|
||||
.into_values();
|
||||
|
||||
let user_signing_key_id = user_signing_key_ids.next().ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"User signing key contained no key.",
|
||||
))?;
|
||||
let user_signing_key_id = user_signing_key_ids
|
||||
.next()
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "User signing key contained no key."))?;
|
||||
|
||||
if user_signing_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
|
@ -543,13 +471,9 @@ impl service::users::Data for KeyValueDatabase {
|
|||
let mut user_signing_key_key = prefix;
|
||||
user_signing_key_key.extend_from_slice(user_signing_key_id.as_bytes());
|
||||
|
||||
self.keyid_key.insert(
|
||||
&user_signing_key_key,
|
||||
user_signing_key.json().get().as_bytes(),
|
||||
)?;
|
||||
self.keyid_key.insert(&user_signing_key_key, user_signing_key.json().get().as_bytes())?;
|
||||
|
||||
self.userid_usersigningkeyid
|
||||
.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
||||
self.userid_usersigningkeyid.insert(user_id.as_bytes(), &user_signing_key_key)?;
|
||||
}
|
||||
|
||||
if notify {
|
||||
|
@ -560,21 +484,18 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn sign_key(
|
||||
&self,
|
||||
target_id: &UserId,
|
||||
key_id: &str,
|
||||
signature: (String, String),
|
||||
sender_id: &UserId,
|
||||
&self, target_id: &UserId, key_id: &str, signature: (String, String), sender_id: &UserId,
|
||||
) -> Result<()> {
|
||||
let mut key = target_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(key_id.as_bytes());
|
||||
|
||||
let mut cross_signing_key: serde_json::Value =
|
||||
serde_json::from_slice(&self.keyid_key.get(&key)?.ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Tried to sign nonexistent key.",
|
||||
))?)
|
||||
let mut cross_signing_key: serde_json::Value = serde_json::from_slice(
|
||||
&self
|
||||
.keyid_key
|
||||
.get(&key)?
|
||||
.ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Tried to sign nonexistent key."))?,
|
||||
)
|
||||
.map_err(|_| Error::bad_database("key in keyid_key is invalid."))?;
|
||||
|
||||
let signatures = cross_signing_key
|
||||
|
@ -601,13 +522,10 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn keys_changed<'a>(
|
||||
&'a self,
|
||||
user_or_room_id: &str,
|
||||
from: u64,
|
||||
to: Option<u64>,
|
||||
&'a self, user_or_room_id: &str, from: u64, to: Option<u64>,
|
||||
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
|
||||
let mut prefix = user_or_room_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut start = prefix.clone();
|
||||
start.extend_from_slice(&(from + 1).to_be_bytes());
|
||||
|
@ -619,7 +537,7 @@ impl service::users::Data for KeyValueDatabase {
|
|||
.iter_from(&start, false)
|
||||
.take_while(move |(k, _)| {
|
||||
k.starts_with(&prefix)
|
||||
&& if let Some(current) = k.splitn(2, |&b| b == 0xff).nth(1) {
|
||||
&& if let Some(current) = k.splitn(2, |&b| b == 0xFF).nth(1) {
|
||||
if let Ok(c) = utils::u64_from_bytes(current) {
|
||||
c <= to
|
||||
} else {
|
||||
|
@ -632,83 +550,63 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
})
|
||||
.map(|(_, bytes)| {
|
||||
UserId::parse(utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database(
|
||||
"User ID in devicekeychangeid_userid is invalid unicode.",
|
||||
UserId::parse(
|
||||
utils::string_from_bytes(&bytes).map_err(|_| {
|
||||
Error::bad_database("User ID in devicekeychangeid_userid is invalid unicode.")
|
||||
})?,
|
||||
)
|
||||
})?)
|
||||
.map_err(|_| {
|
||||
Error::bad_database("User ID in devicekeychangeid_userid is invalid.")
|
||||
})
|
||||
.map_err(|_| Error::bad_database("User ID in devicekeychangeid_userid is invalid."))
|
||||
}),
|
||||
)
|
||||
}
|
||||
|
||||
fn mark_device_key_update(&self, user_id: &UserId) -> Result<()> {
|
||||
let count = services().globals.next_count()?.to_be_bytes();
|
||||
for room_id in services()
|
||||
.rooms
|
||||
.state_cache
|
||||
.rooms_joined(user_id)
|
||||
.filter_map(std::result::Result::ok)
|
||||
{
|
||||
for room_id in services().rooms.state_cache.rooms_joined(user_id).filter_map(std::result::Result::ok) {
|
||||
// Don't send key updates to unencrypted rooms
|
||||
if services()
|
||||
.rooms
|
||||
.state_accessor
|
||||
.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?
|
||||
.is_none()
|
||||
if services().rooms.state_accessor.room_state_get(&room_id, &StateEventType::RoomEncryption, "")?.is_none()
|
||||
{
|
||||
continue;
|
||||
}
|
||||
|
||||
let mut key = room_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&count);
|
||||
|
||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||
}
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&count);
|
||||
self.keychangeid_userid.insert(&key, user_id.as_bytes())?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn get_device_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Option<Raw<DeviceKeys>>> {
|
||||
fn get_device_keys(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Raw<DeviceKeys>>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("DeviceKeys in db are invalid.")
|
||||
})?))
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&bytes).map_err(|_| Error::bad_database("DeviceKeys in db are invalid."))?,
|
||||
))
|
||||
})
|
||||
}
|
||||
|
||||
fn parse_master_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
master_key: &Raw<CrossSigningKey>,
|
||||
&self, user_id: &UserId, master_key: &Raw<CrossSigningKey>,
|
||||
) -> Result<(Vec<u8>, CrossSigningKey)> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let master_key = master_key
|
||||
.deserialize()
|
||||
.map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
||||
let master_key =
|
||||
master_key.deserialize().map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid master key"))?;
|
||||
let mut master_key_ids = master_key.keys.values();
|
||||
let master_key_id = master_key_ids.next().ok_or(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
"Master key contained no key.",
|
||||
))?;
|
||||
let master_key_id =
|
||||
master_key_ids.next().ok_or(Error::BadRequest(ErrorKind::InvalidParam, "Master key contained no key."))?;
|
||||
if master_key_ids.next().is_some() {
|
||||
return Err(Error::BadRequest(
|
||||
ErrorKind::InvalidParam,
|
||||
|
@ -721,79 +619,54 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn get_key(
|
||||
&self,
|
||||
key: &[u8],
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
&self, key: &[u8], sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.keyid_key.get(key)?.map_or(Ok(None), |bytes| {
|
||||
let mut cross_signing_key = serde_json::from_slice::<serde_json::Value>(&bytes)
|
||||
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?;
|
||||
clean_signatures(
|
||||
&mut cross_signing_key,
|
||||
sender_user,
|
||||
user_id,
|
||||
allowed_signatures,
|
||||
)?;
|
||||
clean_signatures(&mut cross_signing_key, sender_user, user_id, allowed_signatures)?;
|
||||
|
||||
Ok(Some(Raw::from_json(
|
||||
serde_json::value::to_raw_value(&cross_signing_key)
|
||||
.expect("Value to RawValue serialization"),
|
||||
serde_json::value::to_raw_value(&cross_signing_key).expect("Value to RawValue serialization"),
|
||||
)))
|
||||
})
|
||||
}
|
||||
|
||||
fn get_master_key(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_masterkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.get_key(&key, sender_user, user_id, allowed_signatures)
|
||||
})
|
||||
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
|
||||
}
|
||||
|
||||
fn get_self_signing_key(
|
||||
&self,
|
||||
sender_user: Option<&UserId>,
|
||||
user_id: &UserId,
|
||||
allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
&self, sender_user: Option<&UserId>, user_id: &UserId, allowed_signatures: &dyn Fn(&UserId) -> bool,
|
||||
) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_selfsigningkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.get_key(&key, sender_user, user_id, allowed_signatures)
|
||||
})
|
||||
.map_or(Ok(None), |key| self.get_key(&key, sender_user, user_id, allowed_signatures))
|
||||
}
|
||||
|
||||
fn get_user_signing_key(&self, user_id: &UserId) -> Result<Option<Raw<CrossSigningKey>>> {
|
||||
self.userid_usersigningkeyid
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |key| {
|
||||
self.userid_usersigningkeyid.get(user_id.as_bytes())?.map_or(Ok(None), |key| {
|
||||
self.keyid_key.get(&key)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("CrossSigningKey in db is invalid.")
|
||||
})?))
|
||||
Ok(Some(
|
||||
serde_json::from_slice(&bytes)
|
||||
.map_err(|_| Error::bad_database("CrossSigningKey in db is invalid."))?,
|
||||
))
|
||||
})
|
||||
})
|
||||
}
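The key lookups in this block all follow one shape; reduced to a standalone sketch (with serde_json's error standing in for the crate's Error::bad_database mapping):

// A tree read yields Option<Vec<u8>>; map_or(Ok(None), ...) turns a missing
// key into Ok(None) while still propagating deserialization failures.
fn read_json<T: serde::de::DeserializeOwned>(bytes: Option<Vec<u8>>) -> Result<Option<T>, serde_json::Error> {
	bytes.map_or(Ok(None), |bytes| Ok(Some(serde_json::from_slice(&bytes)?)))
}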
|
||||
|
||||
fn add_to_device_event(
|
||||
&self,
|
||||
sender: &UserId,
|
||||
target_user_id: &UserId,
|
||||
target_device_id: &DeviceId,
|
||||
event_type: &str,
|
||||
&self, sender: &UserId, target_user_id: &UserId, target_device_id: &DeviceId, event_type: &str,
|
||||
content: serde_json::Value,
|
||||
) -> Result<()> {
|
||||
let mut key = target_user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(target_device_id.as_bytes());
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(&services().globals.next_count()?.to_be_bytes());
|
||||
|
||||
let mut json = serde_json::Map::new();
|
||||
|
@ -808,17 +681,13 @@ impl service::users::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn get_to_device_events(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
|
||||
fn get_to_device_events(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Vec<Raw<AnyToDeviceEvent>>> {
|
||||
let mut events = Vec::new();
|
||||
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
for (_, value) in self.todeviceid_events.scan_prefix(prefix) {
|
||||
events.push(
|
||||
|
@ -830,16 +699,11 @@ impl service::users::Data for KeyValueDatabase {
|
|||
Ok(events)
|
||||
}
|
||||
|
||||
fn remove_to_device_events(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
until: u64,
|
||||
) -> Result<()> {
|
||||
fn remove_to_device_events(&self, user_id: &UserId, device_id: &DeviceId, until: u64) -> Result<()> {
|
||||
let mut prefix = user_id.as_bytes().to_vec();
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
prefix.extend_from_slice(device_id.as_bytes());
|
||||
prefix.push(0xff);
|
||||
prefix.push(0xFF);
|
||||
|
||||
let mut last = prefix.clone();
|
||||
last.extend_from_slice(&until.to_be_bytes());
|
||||
|
@ -864,26 +728,25 @@ impl service::users::Data for KeyValueDatabase {
|
|||
Ok(())
|
||||
}
|
||||
|
||||
fn update_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
device: &Device,
|
||||
) -> Result<()> {
|
||||
fn update_device_metadata(&self, user_id: &UserId, device_id: &DeviceId, device: &Device) -> Result<()> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
// Only existing devices should be able to call this, but we shouldn't assert either...
|
||||
// Only existing devices should be able to call this, but we shouldn't assert
|
||||
// either...
|
||||
if self.userdeviceid_metadata.get(&userdeviceid)?.is_none() {
|
||||
warn!("Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no metadata in database", user_id, device_id);
|
||||
warn!(
|
||||
"Called update_device_metadata for a non-existent user \"{}\" and/or device ID \"{}\" with no \
|
||||
metadata in database",
|
||||
user_id, device_id
|
||||
);
|
||||
return Err(Error::bad_database(
|
||||
"User does not exist or device ID has no metadata in database.",
|
||||
));
|
||||
}
|
||||
|
||||
self.userid_devicelistversion
|
||||
.increment(user_id.as_bytes())?;
|
||||
self.userid_devicelistversion.increment(user_id.as_bytes())?;
|
||||
|
||||
self.userdeviceid_metadata.insert(
|
||||
&userdeviceid,
|
||||
|
@ -894,18 +757,12 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
/// Get device metadata.
|
||||
fn get_device_metadata(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
device_id: &DeviceId,
|
||||
) -> Result<Option<Device>> {
|
||||
fn get_device_metadata(&self, user_id: &UserId, device_id: &DeviceId) -> Result<Option<Device>> {
|
||||
let mut userdeviceid = user_id.as_bytes().to_vec();
|
||||
userdeviceid.push(0xff);
|
||||
userdeviceid.push(0xFF);
|
||||
userdeviceid.extend_from_slice(device_id.as_bytes());
|
||||
|
||||
self.userdeviceid_metadata
|
||||
.get(&userdeviceid)?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
self.userdeviceid_metadata.get(&userdeviceid)?.map_or(Ok(None), |bytes| {
|
||||
Ok(Some(serde_json::from_slice(&bytes).map_err(|_| {
|
||||
Error::bad_database("Metadata in userdeviceid_metadata is invalid.")
|
||||
})?))
|
||||
|
@ -913,31 +770,19 @@ impl service::users::Data for KeyValueDatabase {
|
|||
}
|
||||
|
||||
fn get_devicelist_version(&self, user_id: &UserId) -> Result<Option<u64>> {
|
||||
self.userid_devicelistversion
|
||||
.get(user_id.as_bytes())?
|
||||
.map_or(Ok(None), |bytes| {
|
||||
utils::u64_from_bytes(&bytes)
|
||||
.map_err(|_| Error::bad_database("Invalid devicelistversion in db."))
|
||||
.map(Some)
|
||||
self.userid_devicelistversion.get(user_id.as_bytes())?.map_or(Ok(None), |bytes| {
|
||||
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid devicelistversion in db.")).map(Some)
|
||||
})
|
||||
}
|
||||
|
||||
fn all_devices_metadata<'a>(
|
||||
&'a self,
|
||||
user_id: &UserId,
|
||||
) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
|
||||
fn all_devices_metadata<'a>(&'a self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<Device>> + 'a> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
|
||||
Box::new(
|
||||
self.userdeviceid_metadata
|
||||
.scan_prefix(key)
|
||||
.map(|(_, bytes)| {
|
||||
serde_json::from_slice::<Device>(&bytes).map_err(|_| {
|
||||
Error::bad_database("Device in userdeviceid_metadata is invalid.")
|
||||
})
|
||||
}),
|
||||
)
|
||||
Box::new(self.userdeviceid_metadata.scan_prefix(key).map(|(_, bytes)| {
|
||||
serde_json::from_slice::<Device>(&bytes)
|
||||
.map_err(|_| Error::bad_database("Device in userdeviceid_metadata is invalid."))
|
||||
}))
|
||||
}
|
||||
|
||||
/// Creates a new sync filter. Returns the filter id.
|
||||
|
@ -945,27 +790,23 @@ impl service::users::Data for KeyValueDatabase {
|
|||
let filter_id = utils::random_string(4);
|
||||
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(filter_id.as_bytes());
|
||||
|
||||
self.userfilterid_filter.insert(
|
||||
&key,
|
||||
&serde_json::to_vec(&filter).expect("filter is valid json"),
|
||||
)?;
|
||||
self.userfilterid_filter.insert(&key, &serde_json::to_vec(&filter).expect("filter is valid json"))?;
|
||||
|
||||
Ok(filter_id)
|
||||
}
|
||||
|
||||
fn get_filter(&self, user_id: &UserId, filter_id: &str) -> Result<Option<FilterDefinition>> {
|
||||
let mut key = user_id.as_bytes().to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(filter_id.as_bytes());
|
||||
|
||||
let raw = self.userfilterid_filter.get(&key)?;
|
||||
|
||||
if let Some(raw) = raw {
|
||||
serde_json::from_slice(&raw)
|
||||
.map_err(|_| Error::bad_database("Invalid filter event in db."))
|
||||
serde_json::from_slice(&raw).map_err(|_| Error::bad_database("Invalid filter event in db."))
|
||||
} else {
|
||||
Ok(None)
|
||||
}
|
||||
|
@ -976,8 +817,8 @@ impl KeyValueDatabase {}
|
|||
|
||||
/// Will only return with Some(username) if the password was not empty and the
|
||||
/// username could be successfully parsed.
|
||||
/// If utils::string_from_bytes(...) returns an error that username will be skipped
|
||||
/// and the error will be logged.
|
||||
/// If utils::string_from_bytes(...) returns an error that username will be
|
||||
/// skipped and the error will be logged.
|
||||
fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<String> {
|
||||
// A valid password is not empty
|
||||
if password.is_empty() {
|
||||
|
@ -986,12 +827,9 @@ fn get_username_with_valid_password(username: &[u8], password: &[u8]) -> Option<
|
|||
match utils::string_from_bytes(username) {
|
||||
Ok(u) => Some(u),
|
||||
Err(e) => {
|
||||
warn!(
|
||||
"Failed to parse username while calling get_local_users(): {}",
|
||||
e.to_string()
|
||||
);
|
||||
warn!("Failed to parse username while calling get_local_users(): {}", e.to_string());
|
||||
None
|
||||
}
|
||||
},
|
||||
}
|
||||
}
|
||||
}
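Per the doc comment, the helper rejects empty passwords and otherwise decodes the username bytes; a small behavioural sketch with illustrative values:

#[test]
fn username_with_valid_password_behaviour() {
	// An empty password never yields a username.
	assert_eq!(get_username_with_valid_password(b"alice", b""), None);
	// A non-empty password with UTF-8 username bytes does.
	assert_eq!(get_username_with_valid_password(b"alice", b"hunter2"), Some("alice".to_owned()));
}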
|
||||
@ -1,10 +1,16 @@
|
|||
pub(crate) mod abstraction;
|
||||
pub(crate) mod key_value;
|
||||
|
||||
use crate::{
|
||||
service::rooms::{edus::presence::presence_handler, timeline::PduCount},
|
||||
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fs::{self},
|
||||
io::Write,
|
||||
mem::size_of,
|
||||
path::Path,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use abstraction::{KeyValueDatabaseEngine, KvTree};
|
||||
use argon2::{password_hash::SaltString, PasswordHasher, PasswordVerifier};
|
||||
use itertools::Itertools;
|
||||
|
@ -18,23 +24,17 @@ use ruma::{
|
|||
GlobalAccountDataEvent, GlobalAccountDataEventType, StateEventType,
|
||||
},
|
||||
push::Ruleset,
|
||||
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId,
|
||||
UserId,
|
||||
CanonicalJsonValue, EventId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, UserId,
|
||||
};
|
||||
use serde::Deserialize;
|
||||
use std::{
|
||||
collections::{BTreeMap, HashMap, HashSet},
|
||||
fs::{self},
|
||||
io::Write,
|
||||
mem::size_of,
|
||||
path::Path,
|
||||
sync::{Arc, Mutex, RwLock},
|
||||
time::Duration,
|
||||
};
|
||||
use tokio::{sync::mpsc, time::interval};
|
||||
|
||||
use tracing::{debug, error, info, warn};
|
||||
|
||||
use crate::{
|
||||
service::rooms::{edus::presence::presence_handler, timeline::PduCount},
|
||||
services, utils, Config, Error, PduEvent, Result, Services, SERVICES,
|
||||
};
|
||||
|
||||
pub struct KeyValueDatabase {
|
||||
db: Arc<dyn KeyValueDatabaseEngine>,
|
||||
|
||||
|
@ -128,12 +128,15 @@ pub struct KeyValueDatabase {
|
|||
pub(super) eventid_shorteventid: Arc<dyn KvTree>,
|
||||
|
||||
pub(super) statehash_shortstatehash: Arc<dyn KvTree>,
|
||||
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, // StateDiff = parent (or 0) + (shortstatekey+shorteventid++) + 0_u64 + (shortstatekey+shorteventid--)
|
||||
pub(super) shortstatehash_statediff: Arc<dyn KvTree>, /* StateDiff = parent (or 0) +
|
||||
* (shortstatekey+shorteventid++) + 0_u64 +
|
||||
* (shortstatekey+shorteventid--) */
|
||||
|
||||
pub(super) shorteventid_authchain: Arc<dyn KvTree>,
|
||||
|
||||
/// RoomId + EventId -> outlier PDU.
|
||||
/// Any pdu that has passed the steps 1-8 in the incoming event /federation/send/txn.
|
||||
/// Any pdu that has passed the steps 1-8 in the incoming event
|
||||
/// /federation/send/txn.
|
||||
pub(super) eventid_outlierpdu: Arc<dyn KvTree>,
|
||||
pub(super) softfailedeventids: Arc<dyn KvTree>,
|
||||
|
||||
|
@ -155,11 +158,14 @@ pub struct KeyValueDatabase {
|
|||
pub(super) backupkeyid_backup: Arc<dyn KvTree>, // BackupKeyId = UserId + Version + RoomId + SessionId
|
||||
|
||||
//pub transaction_ids: transaction_ids::TransactionIds,
|
||||
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, // Response can be empty (/sendToDevice) or the event id (/send)
|
||||
pub(super) userdevicetxnid_response: Arc<dyn KvTree>, /* Response can be empty (/sendToDevice) or the event id
|
||||
* (/send) */
|
||||
//pub sending: sending::Sending,
|
||||
pub(super) servername_educount: Arc<dyn KvTree>, // EduCount: Count of last EDU sync
|
||||
pub(super) servernameevent_data: Arc<dyn KvTree>, // ServernameEvent = (+ / $)SenderKey / ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
pub(super) servercurrentevent_data: Arc<dyn KvTree>, // ServerCurrentEvents = (+ / $)ServerName / UserId + PduId / Id (for edus), Data = EDU content
|
||||
pub(super) servernameevent_data: Arc<dyn KvTree>, /* ServernameEvent = (+ / $)SenderKey / ServerName / UserId +
|
||||
* PduId / Id (for edus), Data = EDU content */
|
||||
pub(super) servercurrentevent_data: Arc<dyn KvTree>, /* ServerCurrentEvents = (+ / $)ServerName / UserId + PduId
|
||||
* / Id (for edus), Data = EDU content */
|
||||
|
||||
//pub appservice: appservice::Appservice,
|
||||
pub(super) id_appserviceregistrations: Arc<dyn KvTree>,
|
||||
|
@ -223,10 +229,14 @@ impl KeyValueDatabase {
|
|||
|
||||
if !Path::new(&config.database_path).exists() {
|
||||
debug!("Database path does not exist, assuming this is a new setup and creating it");
|
||||
std::fs::create_dir_all(&config.database_path)
|
||||
.map_err(|e| {
|
||||
std::fs::create_dir_all(&config.database_path).map_err(|e| {
|
||||
error!("Failed to create database path: {e}");
|
||||
Error::BadConfig("Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please create the database folder yourself or allow conduwuit the permissions to create directories and files.")})?;
|
||||
Error::BadConfig(
|
||||
"Database folder doesn't exists and couldn't be created (e.g. due to missing permissions). Please \
|
||||
create the database folder yourself or allow conduwuit the permissions to create directories and \
|
||||
files.",
|
||||
)
|
||||
})?;
|
||||
}
|
||||
|
||||
let builder: Arc<dyn KeyValueDatabaseEngine> = match &*config.database_backend {
|
||||
|
@ -236,17 +246,19 @@ impl KeyValueDatabase {
|
|||
return Err(Error::BadConfig("Database backend not found."));
|
||||
#[cfg(feature = "sqlite")]
|
||||
Arc::new(Arc::<abstraction::sqlite::Engine>::open(&config)?)
|
||||
}
|
||||
},
|
||||
"rocksdb" => {
|
||||
debug!("Got rocksdb database backend");
|
||||
#[cfg(not(feature = "rocksdb"))]
|
||||
return Err(Error::BadConfig("Database backend not found."));
|
||||
#[cfg(feature = "rocksdb")]
|
||||
Arc::new(Arc::<abstraction::rocksdb::Engine>::open(&config)?)
|
||||
}
|
||||
},
|
||||
_ => {
|
||||
return Err(Error::BadConfig("Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends."));
|
||||
}
|
||||
return Err(Error::BadConfig(
|
||||
"Database backend not found. sqlite (not recommended) and rocksdb are the only supported backends.",
|
||||
));
|
||||
},
|
||||
};
|
||||
|
||||
let (presence_sender, presence_receiver) = mpsc::unbounded_channel();
|
||||
|
@ -275,8 +287,7 @@ impl KeyValueDatabase {
|
|||
userdevicesessionid_uiaarequest: RwLock::new(BTreeMap::new()),
|
||||
readreceiptid_readreceipt: builder.open_tree("readreceiptid_readreceipt")?,
|
||||
roomuserid_privateread: builder.open_tree("roomuserid_privateread")?, // "Private" read receipt
|
||||
roomuserid_lastprivatereadupdate: builder
|
||||
.open_tree("roomuserid_lastprivatereadupdate")?,
|
||||
roomuserid_lastprivatereadupdate: builder.open_tree("roomuserid_lastprivatereadupdate")?,
|
||||
typingid_userid: builder.open_tree("typingid_userid")?,
|
||||
roomid_lasttypingupdate: builder.open_tree("roomid_lasttypingupdate")?,
|
||||
roomuserid_presence: builder.open_tree("roomuserid_presence")?,
|
||||
|
@ -352,14 +363,9 @@ impl KeyValueDatabase {
|
|||
|
||||
cached_registrations: Arc::new(RwLock::new(HashMap::new())),
|
||||
pdu_cache: Mutex::new(LruCache::new(
|
||||
config
|
||||
.pdu_cache_capacity
|
||||
.try_into()
|
||||
.expect("pdu cache capacity fits into usize"),
|
||||
)),
|
||||
auth_chain_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
config.pdu_cache_capacity.try_into().expect("pdu cache capacity fits into usize"),
|
||||
)),
|
||||
auth_chain_cache: Mutex::new(LruCache::new((100_000.0 * config.conduit_cache_capacity_modifier) as usize)),
|
||||
shorteventid_cache: Mutex::new(LruCache::new(
|
||||
(100_000.0 * config.conduit_cache_capacity_modifier) as usize,
|
||||
)),
|
||||
|
@ -388,17 +394,14 @@ impl KeyValueDatabase {
|
|||
// Matrix resource ownership is based on the server name; changing it
|
||||
// requires recreating the database from scratch.
|
||||
if services().users.count()? > 0 {
|
||||
let conduit_user =
|
||||
UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||
.expect("@conduit:server_name is valid");
|
||||
|
||||
if !services().users.exists(&conduit_user)? {
|
||||
error!(
|
||||
"The {} server user does not exist, and the database is not new.",
|
||||
conduit_user
|
||||
);
|
||||
error!("The {} server user does not exist, and the database is not new.", conduit_user);
|
||||
return Err(Error::bad_database(
|
||||
"Cannot reuse an existing database after changing the server name, please delete the old one first."
|
||||
"Cannot reuse an existing database after changing the server name, please delete the old one \
|
||||
first.",
|
||||
));
|
||||
}
|
||||
}
|
||||
|
@ -415,17 +418,17 @@ impl KeyValueDatabase {
|
|||
// MIGRATIONS
|
||||
if services().globals.database_version()? < 1 {
|
||||
for (roomserverid, _) in db.roomserverids.iter() {
|
||||
let mut parts = roomserverid.split(|&b| b == 0xff);
|
||||
let mut parts = roomserverid.split(|&b| b == 0xFF);
|
||||
let room_id = parts.next().expect("split always returns one element");
|
||||
let servername = match parts.next() {
|
||||
Some(s) => s,
|
||||
None => {
|
||||
error!("Migration: Invalid roomserverid in db.");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
let mut serverroomid = servername.to_vec();
|
||||
serverroomid.push(0xff);
|
||||
serverroomid.push(0xFF);
|
||||
serverroomid.extend_from_slice(room_id);
|
||||
|
||||
db.serverroomids.insert(&serverroomid, &[])?;
|
||||
|
@ -445,11 +448,8 @@ impl KeyValueDatabase {
|
|||
.argon
|
||||
.hash_password(b"", &salt)
|
||||
.expect("our own password to be properly hashed");
|
||||
let empty_hashed_password = services()
|
||||
.globals
|
||||
.argon
|
||||
.verify_password(&password, &empty_pass)
|
||||
.is_ok();
|
||||
let empty_hashed_password =
|
||||
services().globals.argon.verify_password(&password, &empty_pass).is_ok();
|
||||
|
||||
if empty_hashed_password {
|
||||
db.userid_password.insert(&userid, b"")?;
|
||||
|
@ -506,19 +506,18 @@ impl KeyValueDatabase {
|
|||
if services().globals.database_version()? < 5 {
|
||||
// Upgrade user data store
|
||||
for (roomuserdataid, _) in db.roomuserdataid_accountdata.iter() {
|
||||
let mut parts = roomuserdataid.split(|&b| b == 0xff);
|
||||
let mut parts = roomuserdataid.split(|&b| b == 0xFF);
|
||||
let room_id = parts.next().unwrap();
|
||||
let user_id = parts.next().unwrap();
|
||||
let event_type = roomuserdataid.rsplit(|&b| b == 0xff).next().unwrap();
|
||||
let event_type = roomuserdataid.rsplit(|&b| b == 0xFF).next().unwrap();
|
||||
|
||||
let mut key = room_id.to_vec();
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(user_id);
|
||||
key.push(0xff);
|
||||
key.push(0xFF);
|
||||
key.extend_from_slice(event_type);
|
||||
|
||||
db.roomusertype_roomuserdataid
|
||||
.insert(&key, &roomuserdataid)?;
|
||||
db.roomusertype_roomuserdataid.insert(&key, &roomuserdataid)?;
|
||||
}
|
||||
|
||||
services().globals.bump_database_version(5)?;
|
||||
|
@ -547,8 +546,7 @@ impl KeyValueDatabase {
|
|||
let mut current_state = HashSet::new();
|
||||
let mut counter = 0;
|
||||
|
||||
let mut handle_state =
|
||||
|current_sstatehash: u64,
|
||||
let mut handle_state = |current_sstatehash: u64,
|
||||
current_room: &RoomId,
|
||||
current_state: HashSet<_>,
|
||||
last_roomstates: &mut HashMap<_, _>| {
|
||||
|
@ -558,25 +556,16 @@ impl KeyValueDatabase {
|
|||
let states_parents = last_roomsstatehash.map_or_else(
|
||||
|| Ok(Vec::new()),
|
||||
|&last_roomsstatehash| {
|
||||
services()
|
||||
.rooms
|
||||
.state_compressor
|
||||
.load_shortstatehash_info(last_roomsstatehash)
|
||||
services().rooms.state_compressor.load_shortstatehash_info(last_roomsstatehash)
|
||||
},
|
||||
)?;
|
||||
|
||||
let (statediffnew, statediffremoved) =
|
||||
if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew = current_state
|
||||
.difference(&parent_stateinfo.1)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
|
||||
let statediffnew =
|
||||
current_state.difference(&parent_stateinfo.1).copied().collect::<HashSet<_>>();
|
||||
|
||||
let statediffremoved = parent_stateinfo
|
||||
.1
|
||||
.difference(¤t_state)
|
||||
.copied()
|
||||
.collect::<HashSet<_>>();
|
||||
let statediffremoved =
|
||||
parent_stateinfo.1.difference(¤t_state).copied().collect::<HashSet<_>>();
|
||||
|
||||
(statediffnew, statediffremoved)
|
||||
} else {
|
||||
|
@ -617,8 +606,8 @@ impl KeyValueDatabase {
|
|||
};
|
||||
|
||||
for (k, seventid) in db.db.open_tree("stateid_shorteventid")?.iter() {
|
||||
let sstatehash = utils::u64_from_bytes(&k[0..size_of::<u64>()])
|
||||
.expect("number of bytes is correct");
|
||||
let sstatehash =
|
||||
utils::u64_from_bytes(&k[0..size_of::<u64>()]).expect("number of bytes is correct");
|
||||
let sstatekey = k[size_of::<u64>()..].to_vec();
|
||||
if Some(sstatehash) != current_sstatehash {
|
||||
if let Some(current_sstatehash) = current_sstatehash {
|
||||
|
@ -628,8 +617,7 @@ impl KeyValueDatabase {
|
|||
current_state,
|
||||
&mut last_roomstates,
|
||||
)?;
|
||||
last_roomstates
|
||||
.insert(current_room.clone().unwrap(), current_sstatehash);
|
||||
last_roomstates.insert(current_room.clone().unwrap(), current_sstatehash);
|
||||
}
|
||||
current_state = HashSet::new();
|
||||
current_sstatehash = Some(sstatehash);
|
||||
|
@ -637,12 +625,7 @@ impl KeyValueDatabase {
|
|||
let event_id = db.shorteventid_eventid.get(&seventid).unwrap().unwrap();
|
||||
let string = utils::string_from_bytes(&event_id).unwrap();
|
||||
let event_id = <&EventId>::try_from(string.as_str()).unwrap();
|
||||
let pdu = services()
|
||||
.rooms
|
||||
.timeline
|
||||
.get_pdu(event_id)
|
||||
.unwrap()
|
||||
.unwrap();
|
||||
let pdu = services().rooms.timeline.get_pdu(event_id).unwrap().unwrap();
|
||||
|
||||
if Some(&pdu.room_id) != current_room.as_ref() {
|
||||
current_room = Some(pdu.room_id.clone());
|
||||
|
@ -680,15 +663,11 @@ impl KeyValueDatabase {
|
|||
if !key.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = key.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = key.splitn(2, |&b| b == 0xFF);
|
||||
let room_id = parts.next().unwrap();
|
||||
let count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||
|
||||
let mut new_key = short_room_id;
|
||||
new_key.extend_from_slice(count);
|
||||
|
@ -702,15 +681,11 @@ impl KeyValueDatabase {
|
|||
if !value.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = value.splitn(2, |&b| b == 0xff);
|
||||
let mut parts = value.splitn(2, |&b| b == 0xFF);
|
||||
let room_id = parts.next().unwrap();
|
||||
let count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
let short_room_id = db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||
|
||||
let mut new_value = short_room_id;
|
||||
new_value.extend_from_slice(count);
|
||||
|
@ -734,20 +709,17 @@ impl KeyValueDatabase {
|
|||
if !key.starts_with(b"!") {
|
||||
return None;
|
||||
}
|
||||
let mut parts = key.splitn(4, |&b| b == 0xff);
|
||||
let mut parts = key.splitn(4, |&b| b == 0xFF);
|
||||
let room_id = parts.next().unwrap();
|
||||
let word = parts.next().unwrap();
|
||||
let _pdu_id_room = parts.next().unwrap();
|
||||
let pdu_id_count = parts.next().unwrap();
|
||||
|
||||
let short_room_id = db
|
||||
.roomid_shortroomid
|
||||
.get(room_id)
|
||||
.unwrap()
|
||||
.expect("shortroomid should exist");
|
||||
let short_room_id =
|
||||
db.roomid_shortroomid.get(room_id).unwrap().expect("shortroomid should exist");
|
||||
let mut new_key = short_room_id;
|
||||
new_key.extend_from_slice(word);
|
||||
new_key.push(0xff);
|
||||
new_key.push(0xFF);
|
||||
new_key.extend_from_slice(pdu_id_count);
|
||||
Some((new_key, Vec::new()))
|
||||
})
|
||||
|
@ -784,8 +756,7 @@ impl KeyValueDatabase {
|
|||
if services().globals.database_version()? < 10 {
|
||||
// Add other direction for shortstatekeys
|
||||
for (statekey, shortstatekey) in db.statekey_shortstatekey.iter() {
|
||||
db.shortstatekey_statekey
|
||||
.insert(&shortstatekey, &statekey)?;
|
||||
db.shortstatekey_statekey.insert(&shortstatekey, &statekey)?;
|
||||
}
|
||||
|
||||
// Force E2EE device list updates so we can send them over federation
|
||||
|
@ -799,9 +770,7 @@ impl KeyValueDatabase {
|
|||
}
|
||||
|
||||
if services().globals.database_version()? < 11 {
|
||||
db.db
|
||||
.open_tree("userdevicesessionid_uiaarequest")?
|
||||
.clear()?;
|
||||
db.db.open_tree("userdevicesessionid_uiaarequest")?.clear()?;
|
||||
services().globals.bump_database_version(11)?;
|
||||
|
||||
warn!("Migration: 10 -> 11 finished");
|
||||
|
@ -809,43 +778,33 @@ impl KeyValueDatabase {
|
|||
|
||||
if services().globals.database_version()? < 12 {
|
||||
for username in services().users.list_local_users()? {
|
||||
let user = match UserId::parse_with_server_name(
|
||||
username.clone(),
|
||||
services().globals.server_name(),
|
||||
) {
|
||||
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
|
||||
{
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!("Invalid username {username}: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let raw_rules_list = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
&user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)
|
||||
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
|
||||
.unwrap()
|
||||
.expect("Username is invalid");
|
||||
|
||||
let mut account_data =
|
||||
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||
let rules_list = &mut account_data.content.global;
|
||||
|
||||
//content rule
|
||||
{
|
||||
let content_rule_transformation =
|
||||
[".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
||||
let content_rule_transformation = [".m.rules.contains_user_name", ".m.rule.contains_user_name"];
|
||||
|
||||
let rule = rules_list.content.get(content_rule_transformation[0]);
|
||||
if rule.is_some() {
|
||||
let mut rule = rule.unwrap().clone();
|
||||
rule.rule_id = content_rule_transformation[1].to_owned();
|
||||
rules_list
|
||||
.content
|
||||
.shift_remove(content_rule_transformation[0]);
|
||||
rules_list.content.shift_remove(content_rule_transformation[0]);
|
||||
rules_list.content.insert(rule);
|
||||
}
|
||||
}
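The content-rule step above is essentially a re-keying: the old entry is removed and the rule re-inserted under the new id (the real code also updates the rule's own rule_id field first). The same idea on plain data, with a map standing in for ruma's rule set and illustrative names:

use std::collections::BTreeMap;

// Move the rule stored under the old id to the new id; everything else stays.
fn rename_rule(rules: &mut BTreeMap<String, String>, old_id: &str, new_id: &str) {
	if let Some(rule) = rules.remove(old_id) {
		rules.insert(new_id.to_owned(), rule);
	}
}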
|
||||
|
@ -855,10 +814,7 @@ impl KeyValueDatabase {
|
|||
let underride_rule_transformation = [
|
||||
[".m.rules.call", ".m.rule.call"],
|
||||
[".m.rules.room_one_to_one", ".m.rule.room_one_to_one"],
|
||||
[
|
||||
".m.rules.encrypted_room_one_to_one",
|
||||
".m.rule.encrypted_room_one_to_one",
|
||||
],
|
||||
[".m.rules.encrypted_room_one_to_one", ".m.rule.encrypted_room_one_to_one"],
|
||||
[".m.rules.message", ".m.rule.message"],
|
||||
[".m.rules.encrypted", ".m.rule.encrypted"],
|
||||
];
|
||||
|
@ -887,38 +843,29 @@ impl KeyValueDatabase {
|
|||
warn!("Migration: 11 -> 12 finished");
|
||||
}
|
||||
|
||||
// This migration can be reused as-is anytime the server-default rules are updated.
|
||||
// This migration can be reused as-is anytime the server-default rules are
|
||||
// updated.
|
||||
if services().globals.database_version()? < 13 {
|
||||
for username in services().users.list_local_users()? {
|
||||
let user = match UserId::parse_with_server_name(
|
||||
username.clone(),
|
||||
services().globals.server_name(),
|
||||
) {
|
||||
let user = match UserId::parse_with_server_name(username.clone(), services().globals.server_name())
|
||||
{
|
||||
Ok(u) => u,
|
||||
Err(e) => {
|
||||
warn!("Invalid username {username}: {e}");
|
||||
continue;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let raw_rules_list = services()
|
||||
.account_data
|
||||
.get(
|
||||
None,
|
||||
&user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
)
|
||||
.get(None, &user, GlobalAccountDataEventType::PushRules.to_string().into())
|
||||
.unwrap()
|
||||
.expect("Username is invalid");
|
||||
|
||||
let mut account_data =
|
||||
serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||
let mut account_data = serde_json::from_str::<PushRulesEvent>(raw_rules_list.get()).unwrap();
|
||||
|
||||
let user_default_rules = ruma::push::Ruleset::server_default(&user);
|
||||
account_data
|
||||
.content
|
||||
.global
|
||||
.update_with_server_default(user_default_rules);
|
||||
account_data.content.global.update_with_server_default(user_default_rules);
|
||||
|
||||
services().account_data.update(
|
||||
None,
|
||||
|
@ -937,8 +884,8 @@ impl KeyValueDatabase {
|
|||
warn!("sha256_media feature flag is enabled, migrating legacy base64 file names to sha256 file names");
|
||||
// Move old media files to new names
|
||||
for (key, _) in db.mediaid_file.iter() {
|
||||
// we know that this method is deprecated, but we need to use it to migrate the old files
|
||||
// to the new location
|
||||
// we know that this method is deprecated, but we need to use it to migrate the
|
||||
// old files to the new location
|
||||
//
|
||||
// TODO: remove this once we're sure that all users have migrated
|
||||
#[allow(deprecated)]
|
||||
|
@ -957,7 +904,10 @@ impl KeyValueDatabase {
|
|||
|
||||
assert_eq!(
|
||||
services().globals.database_version().unwrap(),
|
||||
latest_database_version, "Failed asserting local database version {} is equal to known latest conduwuit database version {}", services().globals.database_version().unwrap(), latest_database_version
|
||||
latest_database_version,
|
||||
"Failed asserting local database version {} is equal to known latest conduwuit database version {}",
|
||||
services().globals.database_version().unwrap(),
|
||||
latest_database_version
|
||||
);
|
||||
|
||||
{
|
||||
|
@ -970,10 +920,7 @@ impl KeyValueDatabase {
|
|||
warn!(
|
||||
"User {} matches the following forbidden username patterns: {}",
|
||||
user_id.to_string(),
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|x| &patterns.patterns()[x])
|
||||
.join(", ")
|
||||
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -994,10 +941,7 @@ impl KeyValueDatabase {
|
|||
"Room with alias {} ({}) matches the following forbidden room name patterns: {}",
|
||||
room_alias,
|
||||
&room_id,
|
||||
matches
|
||||
.into_iter()
|
||||
.map(|x| &patterns.patterns()[x])
|
||||
.join(", ")
|
||||
matches.into_iter().map(|x| &patterns.patterns()[x]).join(", ")
|
||||
);
|
||||
}
|
||||
}
|
||||
|
@ -1011,9 +955,7 @@ impl KeyValueDatabase {
|
|||
latest_database_version
|
||||
);
|
||||
} else {
|
||||
services()
|
||||
.globals
|
||||
.bump_database_version(latest_database_version)?;
|
||||
services().globals.bump_database_version(latest_database_version)?;
|
||||
|
||||
// Create the admin room and server user on first run
|
||||
services().admin.create_admin_room().await?;
|
||||
|
@ -1031,16 +973,19 @@ impl KeyValueDatabase {
|
|||
match set_emergency_access() {
|
||||
Ok(pwd_set) => {
|
||||
if pwd_set {
|
||||
warn!("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!");
|
||||
services().admin.send_message(RoomMessageEventContent::text_plain("The Conduit account emergency password is set! Please unset it as soon as you finish admin account recovery!"));
|
||||
}
|
||||
}
|
||||
Err(e) => {
|
||||
error!(
|
||||
"Could not set the configured emergency password for the conduit user: {}",
|
||||
e
|
||||
warn!(
|
||||
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
|
||||
account recovery!"
|
||||
);
|
||||
services().admin.send_message(RoomMessageEventContent::text_plain(
|
||||
"The Conduit account emergency password is set! Please unset it as soon as you finish admin \
|
||||
account recovery!",
|
||||
));
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Could not set the configured emergency password for the conduit user: {}", e);
|
||||
},
|
||||
};
|
||||
|
||||
services().sending.start_handler();
|
||||
|
@ -1079,12 +1024,8 @@ impl KeyValueDatabase {
|
|||
}
|
||||
|
||||
async fn try_handle_updates() -> Result<()> {
|
||||
let response = services()
|
||||
.globals
|
||||
.default_client()
|
||||
.get("https://pupbrain.dev/check-for-updates/stable")
|
||||
.send()
|
||||
.await?;
|
||||
let response =
|
||||
services().globals.default_client().get("https://pupbrain.dev/check-for-updates/stable").send().await?;
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct CheckForUpdatesResponseEntry {
|
||||
|
@ -1097,8 +1038,7 @@ impl KeyValueDatabase {
|
|||
updates: Vec<CheckForUpdatesResponseEntry>,
|
||||
}
|
||||
|
||||
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?)
|
||||
.map_err(|e| {
|
||||
let response = serde_json::from_str::<CheckForUpdatesResponse>(&response.text().await?).map_err(|e| {
|
||||
error!("Bad check for updates response: {e}");
|
||||
Error::BadServerResponse("Bad version check response")
|
||||
})?;
|
||||
|
@ -1108,17 +1048,13 @@ impl KeyValueDatabase {
|
|||
last_update_id = last_update_id.max(update.id);
|
||||
if update.id > services().globals.last_check_for_updates_id()? {
|
||||
error!("{}", update.message);
|
||||
services()
|
||||
.admin
|
||||
.send_message(RoomMessageEventContent::text_plain(format!(
|
||||
services().admin.send_message(RoomMessageEventContent::text_plain(format!(
|
||||
"@room: the following is a message from the conduwuit puppy. it was sent on '{}':\n\n{}",
|
||||
update.date, update.message
|
||||
)));
|
||||
}
|
||||
}
|
||||
services()
|
||||
.globals
|
||||
.update_check_for_updates_id(last_update_id)?;
|
||||
services().globals.update_check_for_updates_id(last_update_id)?;
|
||||
|
||||
Ok(())
|
||||
}
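A self-contained sketch of the payload handling in try_handle_updates, assuming the endpoint returns JSON shaped like {"updates": [{"id": ..., "date": "...", "message": "..."}]}; the field types here are assumptions:

use serde::Deserialize;

#[derive(Deserialize)]
struct UpdateEntry {
	id: u64,
	date: String,
	message: String,
}

#[derive(Deserialize)]
struct UpdatesResponse {
	updates: Vec<UpdateEntry>,
}

// Mirrors the `update.id > last_check_for_updates_id()` filter above: keep
// only entries newer than the last id already announced.
fn unseen_updates(body: &str, last_seen: u64) -> Vec<UpdateEntry> {
	serde_json::from_str::<UpdatesResponse>(body)
		.map(|response| response.updates.into_iter().filter(|update| update.id > last_seen).collect())
		.unwrap_or_default()
}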
|
||||
|
@ -1129,8 +1065,7 @@ impl KeyValueDatabase {
|
|||
use tokio::signal::unix::{signal, SignalKind};
|
||||
use tokio::time::Instant;
|
||||
|
||||
let timer_interval =
|
||||
Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
|
||||
let timer_interval = Duration::from_secs(u64::from(services().globals.config.cleanup_second_interval));
|
||||
|
||||
fn perform_cleanup() {
|
||||
let start = Instant::now();
|
||||
|
@ -1176,9 +1111,7 @@ impl KeyValueDatabase {
|
|||
});
|
||||
}
|
||||
|
||||
pub async fn start_presence_handler(
|
||||
presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>,
|
||||
) {
|
||||
pub async fn start_presence_handler(presence_timer_receiver: mpsc::UnboundedReceiver<(OwnedUserId, Duration)>) {
|
||||
tokio::spawn(async move {
|
||||
match presence_handler(presence_timer_receiver).await {
|
||||
Ok(()) => warn!("Presence maintenance task finished"),
|
||||
|
@ -1188,15 +1121,13 @@ impl KeyValueDatabase {
|
|||
}
|
||||
}
|
||||
|
||||
/// Sets the emergency password and push rules for the @conduit account in case emergency password is set
|
||||
/// Sets the emergency password and push rules for the @conduit account in case
|
||||
/// emergency password is set
|
||||
fn set_emergency_access() -> Result<bool> {
|
||||
let conduit_user = UserId::parse_with_server_name("conduit", services().globals.server_name())
|
||||
.expect("@conduit:server_name is a valid UserId");
|
||||
|
||||
services().users.set_password(
|
||||
&conduit_user,
|
||||
services().globals.emergency_password().as_deref(),
|
||||
)?;
|
||||
services().users.set_password(&conduit_user, services().globals.emergency_password().as_deref())?;
|
||||
|
||||
let (ruleset, res) = match services().globals.emergency_password() {
|
||||
Some(_) => (Ruleset::server_default(&conduit_user), Ok(true)),
|
||||
|
@ -1208,7 +1139,9 @@ fn set_emergency_access() -> Result<bool> {
|
|||
&conduit_user,
|
||||
GlobalAccountDataEventType::PushRules.to_string().into(),
|
||||
&serde_json::to_value(&GlobalAccountDataEvent {
|
||||
content: PushRulesEventContent { global: ruleset },
|
||||
content: PushRulesEventContent {
|
||||
global: ruleset,
|
||||
},
|
||||
})
|
||||
.expect("to json value always works"),
|
||||
)?;
|
||||
@ -15,8 +15,5 @@ pub use utils::error::{Error, Result};
|
|||
pub static SERVICES: RwLock<Option<&'static Services<'static>>> = RwLock::new(None);
|
||||
|
||||
pub fn services() -> &'static Services<'static> {
|
||||
SERVICES
|
||||
.read()
|
||||
.unwrap()
|
||||
.expect("SERVICES should be initialized when this is called")
|
||||
SERVICES.read().unwrap().expect("SERVICES should be initialized when this is called")
|
||||
}
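The services() accessor reads a once-initialized global; a minimal sketch of the same pattern with a stand-in type (initializing via Box::leak is an assumption about how the real value obtains its 'static lifetime):

use std::sync::RwLock;

struct Globals {
	server_name: String,
}

static GLOBALS: RwLock<Option<&'static Globals>> = RwLock::new(None);

// Set exactly once at startup; leaking the box yields the &'static reference.
fn init_globals(globals: Globals) {
	*GLOBALS.write().unwrap() = Some(Box::leak(Box::new(globals)));
}

fn globals() -> &'static Globals {
	GLOBALS.read().unwrap().expect("GLOBALS should be initialized when this is called")
}

fn main() {
	init_globals(Globals {
		server_name: "example.com".to_owned(),
	});
	println!("{}", globals().server_name);
}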
|
||||
|
|
266
src/main.rs
|
@ -1,6 +1,6 @@
|
|||
use std::{
|
||||
fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path,
|
||||
sync::atomic, time::Duration,
|
||||
fs::Permissions, future::Future, io, net::SocketAddr, os::unix::fs::PermissionsExt, path::Path, sync::atomic,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use axum::{
|
||||
|
@ -10,7 +10,11 @@ use axum::{
|
|||
Router,
|
||||
};
|
||||
use axum_server::{bind, bind_rustls, tls_rustls::RustlsConfig, Handle as ServerHandle};
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
use axum_server_dual_protocol::ServerExt;
|
||||
use clap::Parser;
|
||||
use conduit::api::{client_server, server_server};
|
||||
pub use conduit::*; // Re-export everything from the library crate
|
||||
use either::Either::{Left, Right};
|
||||
use figment::{
|
||||
providers::{Env, Format, Toml},
|
||||
|
@ -29,7 +33,14 @@ use ruma::api::{
|
|||
},
|
||||
IncomingRequest,
|
||||
};
|
||||
use tokio::{net::UnixListener, signal, sync::oneshot, task::JoinSet};
|
||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
use tokio::{
|
||||
net::UnixListener,
|
||||
signal,
|
||||
sync::{oneshot, oneshot::Sender},
|
||||
task::JoinSet,
|
||||
};
|
||||
use tower::ServiceBuilder;
|
||||
use tower_http::{
|
||||
cors::{self, CorsLayer},
|
||||
|
@ -39,18 +50,6 @@ use tower_http::{
|
|||
use tracing::{debug, error, info, warn, Level};
|
||||
use tracing_subscriber::{prelude::*, EnvFilter};
|
||||
|
||||
use tokio::sync::oneshot::Sender;
|
||||
|
||||
use clap::Parser;
|
||||
|
||||
#[cfg(feature = "axum_dual_protocol")]
|
||||
use axum_server_dual_protocol::ServerExt;
|
||||
|
||||
pub use conduit::*; // Re-export everything from the library crate
|
||||
|
||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||
use tikv_jemallocator::Jemalloc;
|
||||
|
||||
#[cfg(all(not(target_env = "msvc"), feature = "jemalloc"))]
|
||||
#[global_allocator]
|
||||
static GLOBAL: Jemalloc = Jemalloc;
|
||||
|
@ -67,7 +66,8 @@ async fn main() {
|
|||
Figment::new()
|
||||
.merge(
|
||||
Toml::file(Env::var("CONDUIT_CONFIG").expect(
|
||||
"The CONDUIT_CONFIG environment variable was set but appears to be invalid. This should be set to the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.",
|
||||
"The CONDUIT_CONFIG environment variable was set but appears to be invalid. This should be set to \
|
||||
the path to a valid TOML file, an empty string (for compatibility), or removed/unset entirely.",
|
||||
))
|
||||
.nested(),
|
||||
)
|
||||
|
@ -81,7 +81,7 @@ async fn main() {
|
|||
Err(e) => {
|
||||
eprintln!("It looks like your config is invalid. The following error occurred: {e}");
|
||||
return;
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
if config.allow_jaeger {
|
||||
|
@ -96,21 +96,16 @@ async fn main() {
|
|||
let filter_layer = match EnvFilter::try_new(&config.log) {
|
||||
Ok(s) => s,
|
||||
Err(e) => {
|
||||
eprintln!(
|
||||
"It looks like your log config is invalid. The following error occurred: {e}"
|
||||
);
|
||||
eprintln!("It looks like your log config is invalid. The following error occurred: {e}");
|
||||
EnvFilter::try_new("warn").unwrap()
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let subscriber = tracing_subscriber::Registry::default()
|
||||
.with(filter_layer)
|
||||
.with(telemetry);
|
||||
let subscriber = tracing_subscriber::Registry::default().with(filter_layer).with(telemetry);
|
||||
tracing::subscriber::set_global_default(subscriber).unwrap();
|
||||
} else if config.tracing_flame {
|
||||
let registry = tracing_subscriber::Registry::default();
|
||||
let (flame_layer, _guard) =
|
||||
tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap();
|
||||
let (flame_layer, _guard) = tracing_flame::FlameLayer::with_file("./tracing.folded").unwrap();
|
||||
let flame_layer = flame_layer.with_empty_samples(false);
|
||||
|
||||
let filter_layer = EnvFilter::new("trace,h2=off");
|
||||
|
@ -125,7 +120,7 @@ async fn main() {
|
|||
Err(e) => {
|
||||
eprintln!("It looks like your config is invalid. The following error occured while parsing it: {e}");
|
||||
EnvFilter::try_new("warn").unwrap()
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let subscriber = registry.with(filter_layer).with(fmt_layer);
|
||||
|
@ -172,17 +167,34 @@ async fn main() {
|
|||
if Path::new("/proc/vz").exists() /* Guest */ && !Path::new("/proc/bz").exists()
|
||||
/* Host */
|
||||
{
|
||||
error!("You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
||||
error!(
|
||||
"You are detected using OpenVZ with a loopback/localhost listening address of {}. If you are using \
|
||||
OpenVZ for containers and you use NAT-based networking to communicate with the host and guest, this \
|
||||
will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.",
|
||||
config.address
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
if Path::new("/.dockerenv").exists() {
|
||||
error!("You are detected using Docker with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Docker container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
||||
error!(
|
||||
"You are detected using Docker with a loopback/localhost listening address of {}. If you are using a \
|
||||
reverse proxy on the host and require communication to conduwuit in the Docker container via \
|
||||
NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
|
||||
you can ignore.",
|
||||
config.address
|
||||
);
|
||||
}
|
||||
|
||||
#[cfg(unix)]
|
||||
if Path::new("/run/.containerenv").exists() {
|
||||
error!("You are detected using Podman with a loopback/localhost listening address of {}. If you are using a reverse proxy on the host and require communication to conduwuit in the Podman container via NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, you can ignore.", config.address);
|
||||
error!(
|
||||
"You are detected using Podman with a loopback/localhost listening address of {}. If you are using a \
|
||||
reverse proxy on the host and require communication to conduwuit in the Podman container via \
|
||||
NAT-based networking, this will NOT work. Please change this to \"0.0.0.0\". If this is expected, \
|
||||
you can ignore.",
|
||||
config.address
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -208,18 +220,22 @@ async fn main() {
|
|||
|
||||
// check if user specified valid IP CIDR ranges on startup
|
||||
for cidr in services().globals.ip_range_denylist() {
|
||||
let _ = ipaddress::IPAddress::parse(cidr)
|
||||
.map_err(|e| error!("Error parsing specified IP CIDR range: {e}"));
|
||||
let _ = ipaddress::IPAddress::parse(cidr).map_err(|e| error!("Error parsing specified IP CIDR range: {e}"));
|
||||
}
|
||||
|
||||
if config.allow_registration
|
||||
&& !config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
||||
&& config.registration_token.is_none()
|
||||
{
|
||||
error!("!! You have `allow_registration` enabled without a token configured in your config which means you are allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n
|
||||
If this is not the intended behaviour, please set a registration token with the `registration_token` config option.\n
|
||||
For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour you want, please set the following config option to true:
|
||||
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`");
|
||||
error!(
|
||||
"!! You have `allow_registration` enabled without a token configured in your config which means you are \
|
||||
allowing ANYONE to register on your conduwuit instance without any 2nd-step (e.g. registration token).\n
|
||||
If this is not the intended behaviour, please set a registration token with the `registration_token` config \
|
||||
option.\n
|
||||
For security and safety reasons, conduwuit will shut down. If you are extra sure this is the desired behaviour \
|
||||
you want, please set the following config option to true:
|
||||
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse`"
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
|
@ -227,8 +243,12 @@ async fn main() {
|
|||
&& config.yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse
|
||||
&& config.registration_token.is_none()
|
||||
{
|
||||
warn!("Open registration is enabled via setting `yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to true without a registration token configured. You are expected to be aware of the risks now.\n
|
||||
If this is not the desired behaviour, please set a registration token.");
|
||||
warn!(
|
||||
"Open registration is enabled via setting \
|
||||
`yes_i_am_very_very_sure_i_want_an_open_registration_server_prone_to_abuse` and `allow_registration` to \
|
||||
true without a registration token configured. You are expected to be aware of the risks now.\n
|
||||
If this is not the desired behaviour, please set a registration token."
|
||||
);
|
||||
}
|
||||
|
||||
if config.allow_outgoing_presence && !config.allow_local_presence {
|
||||
|
@ -237,26 +257,33 @@ async fn main() {
|
|||
}
|
||||
|
||||
if config.allow_outgoing_presence {
|
||||
warn!("! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing presence will not be very reliable due to this and any issues with federated outgoing presence are very likely attributed to this issue.\nIncoming presence and local presence are unaffected.");
|
||||
warn!(
|
||||
"! Outgoing federated presence is not spec compliant due to relying on PDUs and EDUs combined.\nOutgoing \
|
||||
presence will not be very reliable due to this and any issues with federated outgoing presence are very \
|
||||
likely attributed to this issue.\nIncoming presence and local presence are unaffected."
|
||||
);
|
||||
}
|
||||
|
||||
if config
|
||||
.url_preview_domain_contains_allowlist
|
||||
.contains(&"*".to_owned())
|
||||
{
|
||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
||||
if config.url_preview_domain_contains_allowlist.contains(&"*".to_owned()) {
|
||||
warn!(
|
||||
"All URLs are allowed for URL previews via setting \"url_preview_domain_contains_allowlist\" to \"*\". \
|
||||
This opens up significant attack surface to your server. You are expected to be aware of the risks by \
|
||||
doing this."
|
||||
);
|
||||
}
|
||||
if config
|
||||
.url_preview_domain_explicit_allowlist
|
||||
.contains(&"*".to_owned())
|
||||
{
|
||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
||||
if config.url_preview_domain_explicit_allowlist.contains(&"*".to_owned()) {
|
||||
warn!(
|
||||
"All URLs are allowed for URL previews via setting \"url_preview_domain_explicit_allowlist\" to \"*\". \
|
||||
This opens up significant attack surface to your server. You are expected to be aware of the risks by \
|
||||
doing this."
|
||||
);
|
||||
}
|
||||
if config
|
||||
.url_preview_url_contains_allowlist
|
||||
.contains(&"*".to_owned())
|
||||
{
|
||||
warn!("All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This opens up significant attack surface to your server. You are expected to be aware of the risks by doing this.");
|
||||
if config.url_preview_url_contains_allowlist.contains(&"*".to_owned()) {
|
||||
warn!(
|
||||
"All URLs are allowed for URL previews via setting \"url_preview_url_contains_allowlist\" to \"*\". This \
|
||||
opens up significant attack surface to your server. You are expected to be aware of the risks by doing \
|
||||
this."
|
||||
);
|
||||
}
|
||||
|
||||
/* end ad-hoc config validation/checks */
|
||||
|
@ -266,8 +293,9 @@ async fn main() {
|
|||
error!("Critical error starting server: {}", e);
|
||||
};
|
||||
|
||||
// if server runs into critical error and shuts down, shut down the tracer provider if jaeger is used.
|
||||
// awaiting run_server() is a blocking call so putting this after is fine, but not the other options above.
|
||||
// if server runs into critical error and shuts down, shut down the tracer
|
||||
// provider if jaeger is used. awaiting run_server() is a blocking call so
|
||||
// putting this after is fine, but not the other options above.
|
||||
if config.allow_jaeger {
|
||||
opentelemetry::global::shutdown_tracer_provider();
|
||||
}
|
||||
|
@ -281,17 +309,9 @@ async fn run_server() -> io::Result<()> {
|
|||
// Left is only 1 value, so make a vec with 1 value only
|
||||
let port_vec = [port];
|
||||
|
||||
port_vec
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|port| SocketAddr::from((config.address, *port)))
|
||||
.collect::<Vec<_>>()
|
||||
}
|
||||
Right(ports) => ports
|
||||
.iter()
|
||||
.copied()
|
||||
.map(|port| SocketAddr::from((config.address, port)))
|
||||
.collect::<Vec<_>>(),
|
||||
port_vec.iter().copied().map(|port| SocketAddr::from((config.address, *port))).collect::<Vec<_>>()
|
||||
},
|
||||
Right(ports) => ports.iter().copied().map(|port| SocketAddr::from((config.address, port))).collect::<Vec<_>>(),
|
||||
};
|
||||
|
||||
let x_requested_with = HeaderName::from_static("x-requested-with");
|
||||
|
@ -336,17 +356,12 @@ async fn run_server() -> io::Result<()> {
|
|||
.max_age(Duration::from_secs(86400)),
|
||||
)
|
||||
.layer(DefaultBodyLimit::max(
|
||||
config
|
||||
.max_request_size
|
||||
.try_into()
|
||||
.expect("failed to convert max request size"),
|
||||
config.max_request_size.try_into().expect("failed to convert max request size"),
|
||||
));
|
||||
|
||||
let app = if cfg!(feature = "zstd_compression") && config.zstd_compression {
|
||||
debug!("zstd body compression is enabled");
|
||||
routes()
|
||||
.layer(middlewares.compression())
|
||||
.into_make_service()
|
||||
routes().layer(middlewares.compression()).into_make_service()
|
||||
} else {
|
||||
routes().layer(middlewares).into_make_service()
|
||||
};
|
||||
|
@ -371,9 +386,7 @@ async fn run_server() -> io::Result<()> {
|
|||
let octal_perms = u32::from_str_radix(&socket_perms, 8).unwrap();
|
||||
|
||||
let listener = UnixListener::bind(path.clone())?;
|
||||
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms))
|
||||
.await
|
||||
.unwrap();
|
||||
tokio::fs::set_permissions(path, Permissions::from_mode(octal_perms)).await.unwrap();
|
||||
let socket = SocketIncoming::from_listener(listener);
|
||||
|
||||
#[cfg(feature = "systemd")]
|
||||
|
@@ -395,12 +408,16 @@ async fn run_server() -> io::Result<()> {
"Using direct TLS. Certificate path {} and certificate private key path {}",
&tls.certs, &tls.key
);
info!("Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit directly with TLS.");
info!(
"Note: It is strongly recommended that you use a reverse proxy instead of running conduwuit \
directly with TLS."
);
let conf = RustlsConfig::from_pem_file(&tls.certs, &tls.key).await?;

if cfg!(feature = "axum_dual_protocol") {
info!(
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This will only take affect if `dual_protocol` is enabled in `[global.tls]`"
"conduwuit was built with axum_dual_protocol feature to listen on both HTTP and HTTPS. This \
will only take affect if `dual_protocol` is enabled in `[global.tls]`"
);
}

@@ -418,11 +435,7 @@ async fn run_server() -> io::Result<()> {
}
} else {
for addr in &addrs {
join_set.spawn(
bind_rustls(*addr, conf.clone())
.handle(handle.clone())
.serve(app.clone()),
);
join_set.spawn(bind_rustls(*addr, conf.clone()).handle(handle.clone()).serve(app.clone()));
}
}

@@ -431,18 +444,16 @@ async fn run_server() -> io::Result<()> {

if cfg!(feature = "axum_dual_protocol") && tls.dual_protocol {
warn!(
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too (insecure!)",
"Listening on {:?} with TLS certificate {} and supporting plain text (HTTP) connections too \
(insecure!)",
addrs, &tls.certs
);
} else {
info!(
"Listening on {:?} with TLS certificate {}",
addrs, &tls.certs
);
info!("Listening on {:?} with TLS certificate {}", addrs, &tls.certs);
}

join_set.join_next().await;
}
},
None => {
let mut join_set = JoinSet::new();
for addr in &addrs {

@@ -454,7 +465,7 @@ async fn run_server() -> io::Result<()> {

info!("Listening on {:?}", addrs);
join_set.join_next().await;
}
},
}
}

@@ -462,20 +473,16 @@ async fn run_server() -> io::Result<()> {
}

async fn spawn_task<B: Send + 'static>(
req: axum::http::Request<B>,
next: axum::middleware::Next<B>,
req: axum::http::Request<B>, next: axum::middleware::Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> {
if services().globals.shutdown.load(atomic::Ordering::Relaxed) {
return Err(StatusCode::SERVICE_UNAVAILABLE);
}
tokio::spawn(next.run(req))
.await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
tokio::spawn(next.run(req)).await.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)
}

async fn unrecognized_method<B: Send + 'static>(
req: axum::http::Request<B>,
next: axum::middleware::Next<B>,
req: axum::http::Request<B>, next: axum::middleware::Next<B>,
) -> std::result::Result<axum::response::Response, StatusCode> {
let method = req.method().clone();
let uri = req.uri().clone();

@@ -637,10 +644,7 @@ fn routes() -> Router {
.ruma_route(client_server::get_relating_events_route)
.ruma_route(client_server::get_hierarchy_route)
.ruma_route(server_server::get_server_version_route)
.route(
"/_matrix/key/v2/server",
get(server_server::get_server_keys_route),
)
.route("/_matrix/key/v2/server", get(server_server::get_server_keys_route))
.route(
"/_matrix/key/v2/server/:key_id",
get(server_server::get_server_keys_deprecated_route),

@@ -663,35 +667,18 @@ fn routes() -> Router {
.ruma_route(server_server::get_profile_information_route)
.ruma_route(server_server::get_keys_route)
.ruma_route(server_server::claim_keys_route)
.route(
"/_matrix/client/r0/rooms/:room_id/initialSync",
get(initial_sync),
)
.route(
"/_matrix/client/v3/rooms/:room_id/initialSync",
get(initial_sync),
)
.route(
"/client/server.json",
get(client_server::syncv3_client_server_json),
)
.route(
"/.well-known/matrix/client",
get(client_server::well_known_client_route),
)
.route(
"/.well-known/matrix/server",
get(server_server::well_known_server_route),
)
.route("/_matrix/client/r0/rooms/:room_id/initialSync", get(initial_sync))
.route("/_matrix/client/v3/rooms/:room_id/initialSync", get(initial_sync))
.route("/client/server.json", get(client_server::syncv3_client_server_json))
.route("/.well-known/matrix/client", get(client_server::well_known_client_route))
.route("/.well-known/matrix/server", get(server_server::well_known_server_route))
.route("/", get(it_works))
.fallback(not_found)
}

async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
let ctrl_c = async {
signal::ctrl_c()
.await
.expect("failed to install Ctrl+C handler");
signal::ctrl_c().await.expect("failed to install Ctrl+C handler");
};

#[cfg(unix)]

@@ -721,19 +708,23 @@ async fn shutdown_signal(handle: ServerHandle, tx: Sender<()>) -> Result<()> {
#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::Stopping]);

tx.send(()).expect("failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your system may not be in an okay/ideal state.)");
tx.send(()).expect(
"failed sending shutdown transaction to oneshot channel (this is unlikely a conduwuit bug and more so your \
system may not be in an okay/ideal state.)",
);

if shutdown_time_elapsed.elapsed() >= Duration::from_secs(60) && cfg!(feature = "systemd") {
warn!("Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 seconds). Remaining connections: {}", handle.connection_count());
warn!(
"Still shutting down after 60 seconds since receiving shutdown signal, asking systemd for more time (+120 \
seconds). Remaining connections: {}",
handle.connection_count()
);

#[cfg(feature = "systemd")]
let _ = sd_notify::notify(true, &[sd_notify::NotifyState::ExtendTimeoutUsec(120)]);
}

warn!(
"Time took to shutdown: {:?} seconds",
shutdown_time_elapsed.elapsed()
);
warn!("Time took to shutdown: {:?} seconds", shutdown_time_elapsed.elapsed());

Ok(())
}

@@ -744,15 +735,10 @@ async fn not_found(uri: Uri) -> impl IntoResponse {
}

async fn initial_sync(_uri: Uri) -> impl IntoResponse {
Error::BadRequest(
ErrorKind::GuestAccessForbidden,
"Guest access not implemented",
)
Error::BadRequest(ErrorKind::GuestAccessForbidden, "Guest access not implemented")
}

async fn it_works() -> &'static str {
"hewwo from conduwuit woof!"
}
async fn it_works() -> &'static str { "hewwo from conduwuit woof!" }

trait RouterExt {
fn ruma_route<H, T>(self, handler: H) -> Self

@@ -773,8 +759,8 @@ impl RouterExt for Router {

pub trait RumaHandler<T> {
// Can't transform to a handler without boxing or relying on the nightly-only
// impl-trait-in-traits feature. Moving a small amount of extra logic into the trait
// allows bypassing both.
// impl-trait-in-traits feature. Moving a small amount of extra logic into the
// trait allows bypassing both.
fn add_to_router(self, router: Router) -> Router;
}

@ -1,35 +1,28 @@
|
|||
use std::collections::HashMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()>;
|
||||
|
||||
/// Searches the account data for a specific kind.
|
||||
fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
kind: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, kind: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>>;
|
||||
|
||||
/// Returns all changes to the account data that happened after `since`.
|
||||
fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>>;
|
||||
}
|
||||
|
|
|
@ -1,15 +1,14 @@
|
|||
mod data;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use std::collections::HashMap;
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use ruma::{
|
||||
events::{AnyEphemeralRoomEvent, RoomAccountDataEventType},
|
||||
serde::Raw,
|
||||
RoomId, UserId,
|
||||
};
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
|
@ -17,13 +16,11 @@ pub struct Service {
|
|||
}
|
||||
|
||||
impl Service {
|
||||
/// Places one event in the account data of the user and removes the previous entry.
|
||||
/// Places one event in the account data of the user and removes the
|
||||
/// previous entry.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type, data))]
|
||||
pub fn update(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
data: &serde_json::Value,
|
||||
) -> Result<()> {
|
||||
self.db.update(room_id, user_id, event_type, data)
|
||||
|
@ -32,10 +29,7 @@ impl Service {
|
|||
/// Searches the account data for a specific kind.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, event_type))]
|
||||
pub fn get(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
event_type: RoomAccountDataEventType,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, event_type: RoomAccountDataEventType,
|
||||
) -> Result<Option<Box<serde_json::value::RawValue>>> {
|
||||
self.db.get(room_id, user_id, event_type)
|
||||
}
|
||||
|
@ -43,10 +37,7 @@ impl Service {
|
|||
/// Returns all changes to the account data that happened after `since`.
|
||||
#[tracing::instrument(skip(self, room_id, user_id, since))]
|
||||
pub fn changes_since(
|
||||
&self,
|
||||
room_id: Option<&RoomId>,
|
||||
user_id: &UserId,
|
||||
since: u64,
|
||||
&self, room_id: Option<&RoomId>, user_id: &UserId, since: u64,
|
||||
) -> Result<HashMap<RoomAccountDataEventType, Raw<AnyEphemeralRoomEvent>>> {
|
||||
self.db.changes_since(room_id, user_id, since)
|
||||
}
|
||||
|
|
File diff suppressed because it is too large
@ -11,9 +11,7 @@ pub struct Service {
|
|||
|
||||
impl Service {
|
||||
/// Registers an appservice and returns the ID to the caller
|
||||
pub fn register_appservice(&self, yaml: Registration) -> Result<String> {
|
||||
self.db.register_appservice(yaml)
|
||||
}
|
||||
pub fn register_appservice(&self, yaml: Registration) -> Result<String> { self.db.register_appservice(yaml) }
|
||||
|
||||
/// Remove an appservice registration
|
||||
///
|
||||
|
@ -24,15 +22,9 @@ impl Service {
|
|||
self.db.unregister_appservice(service_name)
|
||||
}
|
||||
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> {
|
||||
self.db.get_registration(id)
|
||||
}
|
||||
pub fn get_registration(&self, id: &str) -> Result<Option<Registration>> { self.db.get_registration(id) }
|
||||
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> {
|
||||
self.db.iter_ids()
|
||||
}
|
||||
pub fn iter_ids(&self) -> Result<impl Iterator<Item = Result<String>> + '_> { self.db.iter_ids() }
|
||||
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> {
|
||||
self.db.all()
|
||||
}
|
||||
pub fn all(&self) -> Result<Vec<(String, Registration)>> { self.db.all() }
|
||||
}
|
||||
|
|
|
@ -22,16 +22,12 @@ pub trait Data: Send + Sync {
|
|||
fn load_keypair(&self) -> Result<Ed25519KeyPair>;
|
||||
fn remove_keypair(&self) -> Result<()>;
|
||||
fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>>;
|
||||
fn database_version(&self) -> Result<u64>;
|
||||
fn bump_database_version(&self, new_version: u64) -> Result<()>;
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ use std::{
|
|||
|
||||
use argon2::Argon2;
|
||||
use base64::{engine::general_purpose, Engine as _};
|
||||
pub use data::Data;
|
||||
use futures_util::FutureExt;
|
||||
use hyper::{
|
||||
client::connect::dns::{GaiResolver, Name},
|
||||
|
@ -27,21 +28,16 @@ use ruma::{
|
|||
client::sync::sync_events,
|
||||
federation::discovery::{ServerSigningKeys, VerifyKey},
|
||||
},
|
||||
DeviceId, RoomVersionId, ServerName, UserId,
|
||||
};
|
||||
use ruma::{
|
||||
serde::Base64, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName,
|
||||
OwnedServerSigningKeyId, OwnedUserId,
|
||||
serde::Base64,
|
||||
DeviceId, OwnedDeviceId, OwnedEventId, OwnedRoomId, OwnedServerName, OwnedServerSigningKeyId, OwnedUserId,
|
||||
RoomVersionId, ServerName, UserId,
|
||||
};
|
||||
use sha2::Digest;
|
||||
use tokio::sync::{broadcast, watch::Receiver, Mutex as TokioMutex, Semaphore};
|
||||
use tracing::{error, info};
|
||||
use trust_dns_resolver::TokioAsyncResolver;
|
||||
|
||||
pub use data::Data;
|
||||
|
||||
use crate::api::server_server::FedDest;
|
||||
use crate::{services, Config, Error, Result};
|
||||
use crate::{api::server_server::FedDest, services, Config, Error, Result};
|
||||
|
||||
mod data;
|
||||
|
||||
|
@ -83,9 +79,11 @@ pub struct Service<'a> {
|
|||
pub argon: Argon2<'a>,
|
||||
}
|
||||
|
||||
/// Handles "rotation" of long-polling requests. "Rotation" in this context is similar to "rotation" of log files and the like.
|
||||
/// Handles "rotation" of long-polling requests. "Rotation" in this context is
|
||||
/// similar to "rotation" of log files and the like.
|
||||
///
|
||||
/// This is utilized to have sync workers return early and release read locks on the database.
|
||||
/// This is utilized to have sync workers return early and release read locks on
|
||||
/// the database.
|
||||
pub(crate) struct RotationHandler(broadcast::Sender<()>, ());
|
||||
|
||||
impl RotationHandler {
|
||||
|
@ -102,15 +100,11 @@ impl RotationHandler {
|
|||
}
|
||||
}
|
||||
|
||||
pub fn fire(&self) {
|
||||
let _ = self.0.send(());
|
||||
}
|
||||
pub fn fire(&self) { let _ = self.0.send(()); }
|
||||
}
|
||||
|
||||
impl Default for RotationHandler {
|
||||
fn default() -> Self {
|
||||
Self::new()
|
||||
}
|
||||
fn default() -> Self { Self::new() }
|
||||
}
|
||||
|
||||
struct Resolver {
|
||||
|
@ -162,15 +156,13 @@ impl Service<'_> {
|
|||
error!("Keypair invalid. Deleting...");
|
||||
db.remove_keypair()?;
|
||||
return Err(e);
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
let tls_name_override = Arc::new(RwLock::new(TlsNameMap::new()));
|
||||
|
||||
let jwt_decoding_key = config
|
||||
.jwt_secret
|
||||
.as_ref()
|
||||
.map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
||||
let jwt_decoding_key =
|
||||
config.jwt_secret.as_ref().map(|secret| jsonwebtoken::DecodingKey::from_secret(secret.as_bytes()));
|
||||
|
||||
let url_preview_client = url_preview_reqwest_client_builder(&config)?.build()?;
|
||||
let default_client = reqwest_client_builder(&config)?.build()?;
|
||||
|
@ -205,10 +197,7 @@ impl Service<'_> {
|
|||
config,
|
||||
keypair: Arc::new(keypair),
|
||||
dns_resolver: TokioAsyncResolver::tokio_from_system_conf().map_err(|e| {
|
||||
error!(
|
||||
"Failed to set up trust dns resolver with system config: {}",
|
||||
e
|
||||
);
|
||||
error!("Failed to set up trust dns resolver with system config: {}", e);
|
||||
Error::bad_config("Failed to set up trust dns resolver with system config.")
|
||||
})?,
|
||||
actual_destination_cache: Arc::new(RwLock::new(WellKnownMap::new())),
|
||||
|
@ -236,10 +225,7 @@ impl Service<'_> {
|
|||
|
||||
fs::create_dir_all(s.get_media_folder())?;
|
||||
|
||||
if !s
|
||||
.supported_room_versions()
|
||||
.contains(&s.config.default_room_version)
|
||||
{
|
||||
if !s.supported_room_versions().contains(&s.config.default_room_version) {
|
||||
error!(config=?s.config.default_room_version, fallback=?crate::config::default_default_room_version(), "Room version in config isn't supported, falling back to default version");
|
||||
s.config.default_room_version = crate::config::default_default_room_version();
|
||||
};
|
||||
|
@ -248,12 +234,11 @@ impl Service<'_> {
|
|||
}
|
||||
|
||||
/// Returns this server's keypair.
|
||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair {
|
||||
&self.keypair
|
||||
}
|
||||
pub fn keypair(&self) -> &ruma::signatures::Ed25519KeyPair { &self.keypair }
|
||||
|
||||
/// Returns a reqwest client which can be used to send requests for URL previews
|
||||
/// This is the same as `default_client()` except a redirect policy of max 2 is set
|
||||
/// Returns a reqwest client which can be used to send requests for URL
|
||||
/// previews This is the same as `default_client()` except a redirect policy
|
||||
/// of max 2 is set
|
||||
pub fn url_preview_client(&self) -> reqwest::Client {
|
||||
// Client is cheap to clone (Arc wrapper) and avoids lifetime issues
|
||||
self.url_preview_client.clone()
|
||||
|
@ -272,60 +257,36 @@ impl Service<'_> {
|
|||
}
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn next_count(&self) -> Result<u64> {
|
||||
self.db.next_count()
|
||||
}
|
||||
pub fn next_count(&self) -> Result<u64> { self.db.next_count() }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn current_count(&self) -> Result<u64> {
|
||||
self.db.current_count()
|
||||
}
|
||||
pub fn current_count(&self) -> Result<u64> { self.db.current_count() }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn last_check_for_updates_id(&self) -> Result<u64> {
|
||||
self.db.last_check_for_updates_id()
|
||||
}
|
||||
pub fn last_check_for_updates_id(&self) -> Result<u64> { self.db.last_check_for_updates_id() }
|
||||
|
||||
#[tracing::instrument(skip(self))]
|
||||
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> {
|
||||
self.db.update_check_for_updates_id(id)
|
||||
}
|
||||
pub fn update_check_for_updates_id(&self, id: u64) -> Result<()> { self.db.update_check_for_updates_id(id) }
|
||||
|
||||
pub async fn watch(&self, user_id: &UserId, device_id: &DeviceId) -> Result<()> {
|
||||
self.db.watch(user_id, device_id).await
|
||||
}
|
||||
|
||||
pub fn cleanup(&self) -> Result<()> {
|
||||
self.db.cleanup()
|
||||
}
|
||||
pub fn cleanup(&self) -> Result<()> { self.db.cleanup() }
|
||||
|
||||
pub fn server_name(&self) -> &ServerName {
|
||||
self.config.server_name.as_ref()
|
||||
}
|
||||
pub fn server_name(&self) -> &ServerName { self.config.server_name.as_ref() }
|
||||
|
||||
pub fn max_request_size(&self) -> u32 {
|
||||
self.config.max_request_size
|
||||
}
|
||||
pub fn max_request_size(&self) -> u32 { self.config.max_request_size }
|
||||
|
||||
pub fn max_fetch_prev_events(&self) -> u16 {
|
||||
self.config.max_fetch_prev_events
|
||||
}
|
||||
pub fn max_fetch_prev_events(&self) -> u16 { self.config.max_fetch_prev_events }
|
||||
|
||||
pub fn allow_registration(&self) -> bool {
|
||||
self.config.allow_registration
|
||||
}
|
||||
pub fn allow_registration(&self) -> bool { self.config.allow_registration }
|
||||
|
||||
pub fn allow_guest_registration(&self) -> bool {
|
||||
self.config.allow_guest_registration
|
||||
}
|
||||
pub fn allow_guest_registration(&self) -> bool { self.config.allow_guest_registration }
|
||||
|
||||
pub fn allow_encryption(&self) -> bool {
|
||||
self.config.allow_encryption
|
||||
}
|
||||
pub fn allow_encryption(&self) -> bool { self.config.allow_encryption }
|
||||
|
||||
pub fn allow_federation(&self) -> bool {
|
||||
self.config.allow_federation
|
||||
}
|
||||
pub fn allow_federation(&self) -> bool { self.config.allow_federation }
|
||||
|
||||
pub fn allow_public_room_directory_over_federation(&self) -> bool {
|
||||
self.config.allow_public_room_directory_over_federation
|
||||
|
@ -335,73 +296,39 @@ impl Service<'_> {
|
|||
self.config.allow_public_room_directory_without_auth
|
||||
}
|
||||
|
||||
pub fn allow_device_name_federation(&self) -> bool {
|
||||
self.config.allow_device_name_federation
|
||||
}
|
||||
pub fn allow_device_name_federation(&self) -> bool { self.config.allow_device_name_federation }
|
||||
|
||||
pub fn allow_room_creation(&self) -> bool {
|
||||
self.config.allow_room_creation
|
||||
}
|
||||
pub fn allow_room_creation(&self) -> bool { self.config.allow_room_creation }
|
||||
|
||||
pub fn allow_unstable_room_versions(&self) -> bool {
|
||||
self.config.allow_unstable_room_versions
|
||||
}
|
||||
pub fn allow_unstable_room_versions(&self) -> bool { self.config.allow_unstable_room_versions }
|
||||
|
||||
pub fn default_room_version(&self) -> RoomVersionId {
|
||||
self.config.default_room_version.clone()
|
||||
}
|
||||
pub fn default_room_version(&self) -> RoomVersionId { self.config.default_room_version.clone() }
|
||||
|
||||
pub fn new_user_displayname_suffix(&self) -> &String {
|
||||
&self.config.new_user_displayname_suffix
|
||||
}
|
||||
pub fn new_user_displayname_suffix(&self) -> &String { &self.config.new_user_displayname_suffix }
|
||||
|
||||
pub fn allow_check_for_updates(&self) -> bool {
|
||||
self.config.allow_check_for_updates
|
||||
}
|
||||
pub fn allow_check_for_updates(&self) -> bool { self.config.allow_check_for_updates }
|
||||
|
||||
pub fn trusted_servers(&self) -> &[OwnedServerName] {
|
||||
&self.config.trusted_servers
|
||||
}
|
||||
pub fn trusted_servers(&self) -> &[OwnedServerName] { &self.config.trusted_servers }
|
||||
|
||||
pub fn query_trusted_key_servers_first(&self) -> bool {
|
||||
self.config.query_trusted_key_servers_first
|
||||
}
|
||||
pub fn query_trusted_key_servers_first(&self) -> bool { self.config.query_trusted_key_servers_first }
|
||||
|
||||
pub fn dns_resolver(&self) -> &TokioAsyncResolver {
|
||||
&self.dns_resolver
|
||||
}
|
||||
pub fn dns_resolver(&self) -> &TokioAsyncResolver { &self.dns_resolver }
|
||||
|
||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> {
|
||||
self.jwt_decoding_key.as_ref()
|
||||
}
|
||||
pub fn jwt_decoding_key(&self) -> Option<&jsonwebtoken::DecodingKey> { self.jwt_decoding_key.as_ref() }
|
||||
|
||||
pub fn turn_password(&self) -> &String {
|
||||
&self.config.turn_password
|
||||
}
|
||||
pub fn turn_password(&self) -> &String { &self.config.turn_password }
|
||||
|
||||
pub fn turn_ttl(&self) -> u64 {
|
||||
self.config.turn_ttl
|
||||
}
|
||||
pub fn turn_ttl(&self) -> u64 { self.config.turn_ttl }
|
||||
|
||||
pub fn turn_uris(&self) -> &[String] {
|
||||
&self.config.turn_uris
|
||||
}
|
||||
pub fn turn_uris(&self) -> &[String] { &self.config.turn_uris }
|
||||
|
||||
pub fn turn_username(&self) -> &String {
|
||||
&self.config.turn_username
|
||||
}
|
||||
pub fn turn_username(&self) -> &String { &self.config.turn_username }
|
||||
|
||||
pub fn turn_secret(&self) -> &String {
|
||||
&self.config.turn_secret
|
||||
}
|
||||
pub fn turn_secret(&self) -> &String { &self.config.turn_secret }
|
||||
|
||||
pub fn notification_push_path(&self) -> &String {
|
||||
&self.config.notification_push_path
|
||||
}
|
||||
pub fn notification_push_path(&self) -> &String { &self.config.notification_push_path }
|
||||
|
||||
pub fn emergency_password(&self) -> &Option<String> {
|
||||
&self.config.emergency_password
|
||||
}
|
||||
pub fn emergency_password(&self) -> &Option<String> { &self.config.emergency_password }
|
||||
|
||||
pub fn url_preview_domain_contains_allowlist(&self) -> &Vec<String> {
|
||||
&self.config.url_preview_domain_contains_allowlist
|
||||
|
@ -411,77 +338,41 @@ impl Service<'_> {
|
|||
&self.config.url_preview_domain_explicit_allowlist
|
||||
}
|
||||
|
||||
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> {
|
||||
&self.config.url_preview_url_contains_allowlist
|
||||
}
|
||||
pub fn url_preview_url_contains_allowlist(&self) -> &Vec<String> { &self.config.url_preview_url_contains_allowlist }
|
||||
|
||||
pub fn url_preview_max_spider_size(&self) -> usize {
|
||||
self.config.url_preview_max_spider_size
|
||||
}
|
||||
pub fn url_preview_max_spider_size(&self) -> usize { self.config.url_preview_max_spider_size }
|
||||
|
||||
pub fn url_preview_check_root_domain(&self) -> bool {
|
||||
self.config.url_preview_check_root_domain
|
||||
}
|
||||
pub fn url_preview_check_root_domain(&self) -> bool { self.config.url_preview_check_root_domain }
|
||||
|
||||
pub fn forbidden_room_names(&self) -> &RegexSet {
|
||||
&self.config.forbidden_room_names
|
||||
}
|
||||
pub fn forbidden_room_names(&self) -> &RegexSet { &self.config.forbidden_room_names }
|
||||
|
||||
pub fn forbidden_usernames(&self) -> &RegexSet {
|
||||
&self.config.forbidden_usernames
|
||||
}
|
||||
pub fn forbidden_usernames(&self) -> &RegexSet { &self.config.forbidden_usernames }
|
||||
|
||||
pub fn allow_local_presence(&self) -> bool {
|
||||
self.config.allow_local_presence
|
||||
}
|
||||
pub fn allow_local_presence(&self) -> bool { self.config.allow_local_presence }
|
||||
|
||||
pub fn allow_incoming_presence(&self) -> bool {
|
||||
self.config.allow_incoming_presence
|
||||
}
|
||||
pub fn allow_incoming_presence(&self) -> bool { self.config.allow_incoming_presence }
|
||||
|
||||
pub fn allow_outgoing_presence(&self) -> bool {
|
||||
self.config.allow_outgoing_presence
|
||||
}
|
||||
pub fn allow_outgoing_presence(&self) -> bool { self.config.allow_outgoing_presence }
|
||||
|
||||
pub fn presence_idle_timeout_s(&self) -> u64 {
|
||||
self.config.presence_idle_timeout_s
|
||||
}
|
||||
pub fn presence_idle_timeout_s(&self) -> u64 { self.config.presence_idle_timeout_s }
|
||||
|
||||
pub fn presence_offline_timeout_s(&self) -> u64 {
|
||||
self.config.presence_offline_timeout_s
|
||||
}
|
||||
pub fn presence_offline_timeout_s(&self) -> u64 { self.config.presence_offline_timeout_s }
|
||||
|
||||
pub fn rocksdb_log_level(&self) -> &String {
|
||||
&self.config.rocksdb_log_level
|
||||
}
|
||||
pub fn rocksdb_log_level(&self) -> &String { &self.config.rocksdb_log_level }
|
||||
|
||||
pub fn rocksdb_max_log_file_size(&self) -> usize {
|
||||
self.config.rocksdb_max_log_file_size
|
||||
}
|
||||
pub fn rocksdb_max_log_file_size(&self) -> usize { self.config.rocksdb_max_log_file_size }
|
||||
|
||||
pub fn rocksdb_log_time_to_roll(&self) -> usize {
|
||||
self.config.rocksdb_log_time_to_roll
|
||||
}
|
||||
pub fn rocksdb_log_time_to_roll(&self) -> usize { self.config.rocksdb_log_time_to_roll }
|
||||
|
||||
pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool {
|
||||
self.config.rocksdb_optimize_for_spinning_disks
|
||||
}
|
||||
pub fn rocksdb_optimize_for_spinning_disks(&self) -> bool { self.config.rocksdb_optimize_for_spinning_disks }
|
||||
|
||||
pub fn rocksdb_parallelism_threads(&self) -> usize {
|
||||
self.config.rocksdb_parallelism_threads
|
||||
}
|
||||
pub fn rocksdb_parallelism_threads(&self) -> usize { self.config.rocksdb_parallelism_threads }
|
||||
|
||||
pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] {
|
||||
&self.config.prevent_media_downloads_from
|
||||
}
|
||||
pub fn prevent_media_downloads_from(&self) -> &[OwnedServerName] { &self.config.prevent_media_downloads_from }
|
||||
|
||||
pub fn ip_range_denylist(&self) -> &[String] {
|
||||
&self.config.ip_range_denylist
|
||||
}
|
||||
pub fn ip_range_denylist(&self) -> &[String] { &self.config.ip_range_denylist }
|
||||
|
||||
pub fn block_non_admin_invites(&self) -> bool {
|
||||
self.config.block_non_admin_invites
|
||||
}
|
||||
pub fn block_non_admin_invites(&self) -> bool { self.config.block_non_admin_invites }
|
||||
|
||||
pub fn supported_room_versions(&self) -> Vec<RoomVersionId> {
|
||||
let mut room_versions: Vec<RoomVersionId> = vec![];
|
||||
|
@ -492,24 +383,22 @@ impl Service<'_> {
|
|||
room_versions
|
||||
}
|
||||
|
||||
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored in room version > 4
|
||||
/// TODO: the key valid until timestamp (`valid_until_ts`) is only honored
|
||||
/// in room version > 4
|
||||
///
|
||||
/// Remove the outdated keys and insert the new ones.
|
||||
///
|
||||
/// This doesn't actually check that the keys provided are newer than the old set.
|
||||
/// This doesn't actually check that the keys provided are newer than the
|
||||
/// old set.
|
||||
pub fn add_signing_key(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
new_keys: ServerSigningKeys,
|
||||
&self, origin: &ServerName, new_keys: ServerSigningKeys,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
self.db.add_signing_key(origin, new_keys)
|
||||
}
|
||||
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found for the server.
|
||||
pub fn signing_keys_for(
|
||||
&self,
|
||||
origin: &ServerName,
|
||||
) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
/// This returns an empty `Ok(BTreeMap<..>)` when there are no keys found
|
||||
/// for the server.
|
||||
pub fn signing_keys_for(&self, origin: &ServerName) -> Result<BTreeMap<OwnedServerSigningKeyId, VerifyKey>> {
|
||||
let mut keys = self.db.signing_keys_for(origin)?;
|
||||
if origin == self.server_name() {
|
||||
keys.insert(
|
||||
|
@ -525,13 +414,9 @@ impl Service<'_> {
|
|||
Ok(keys)
|
||||
}
|
||||
|
||||
pub fn database_version(&self) -> Result<u64> {
|
||||
self.db.database_version()
|
||||
}
|
||||
pub fn database_version(&self) -> Result<u64> { self.db.database_version() }
|
||||
|
||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> {
|
||||
self.db.bump_database_version(new_version)
|
||||
}
|
||||
pub fn bump_database_version(&self, new_version: u64) -> Result<()> { self.db.bump_database_version(new_version) }
|
||||
|
||||
pub fn get_media_folder(&self) -> PathBuf {
|
||||
let mut r = PathBuf::new();
|
||||
|
@ -540,20 +425,23 @@ impl Service<'_> {
|
|||
r
|
||||
}
|
||||
|
||||
/// new SHA256 file name media function, requires "sha256_media" feature flag enabled and database migrated
|
||||
/// uses SHA256 hash of the base64 key as the file name
|
||||
/// new SHA256 file name media function, requires "sha256_media" feature
|
||||
/// flag enabled and database migrated uses SHA256 hash of the base64 key as
|
||||
/// the file name
|
||||
pub fn get_media_file_new(&self, key: &[u8]) -> PathBuf {
|
||||
let mut r = PathBuf::new();
|
||||
r.push(self.config.database_path.clone());
|
||||
r.push("media");
|
||||
// Using the hash of the base64 key as the filename
|
||||
// This is to prevent the total length of the path from exceeding the maximum length in most filesystems
|
||||
// This is to prevent the total length of the path from exceeding the maximum
|
||||
// length in most filesystems
|
||||
r.push(general_purpose::URL_SAFE_NO_PAD.encode(sha2::Sha256::digest(key)));
|
||||
r
|
||||
}
|
||||
|
||||
/// old base64 file name media function
|
||||
/// This is the old version of `get_media_file` that uses the full base64 key as the filename.
|
||||
/// This is the old version of `get_media_file` that uses the full base64
|
||||
/// key as the filename.
|
||||
///
|
||||
/// This is deprecated and will be removed in a future release.
|
||||
/// Please use `get_media_file_new` instead.
|
||||
|
@ -566,17 +454,11 @@ impl Service<'_> {
|
|||
r
|
||||
}
|
||||
|
||||
pub fn well_known_client(&self) -> &Option<String> {
|
||||
&self.config.well_known_client
|
||||
}
|
||||
pub fn well_known_client(&self) -> &Option<String> { &self.config.well_known_client }
|
||||
|
||||
pub fn well_known_server(&self) -> &Option<String> {
|
||||
&self.config.well_known_server
|
||||
}
|
||||
pub fn well_known_server(&self) -> &Option<String> { &self.config.well_known_server }
|
||||
|
||||
pub fn unix_socket_path(&self) -> &Option<PathBuf> {
|
||||
&self.config.unix_socket_path
|
||||
}
|
||||
pub fn unix_socket_path(&self) -> &Option<PathBuf> { &self.config.unix_socket_path }
|
||||
|
||||
pub fn shutdown(&self) {
|
||||
self.shutdown.store(true, atomic::Ordering::Relaxed);
|
||||
|
@ -586,7 +468,7 @@ impl Service<'_> {
|
|||
match &self.unix_socket_path() {
|
||||
Some(path) => {
|
||||
std::fs::remove_file(path).unwrap();
|
||||
}
|
||||
},
|
||||
None => error!(
|
||||
"Unable to remove socket file at {:?} during shutdown.",
|
||||
&self.unix_socket_path()
|
||||
|
@ -613,11 +495,7 @@ fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
|||
.connect_timeout(Duration::from_secs(60))
|
||||
.timeout(Duration::from_secs(60 * 5))
|
||||
.redirect(redirect_policy)
|
||||
.user_agent(concat!(
|
||||
env!("CARGO_PKG_NAME"),
|
||||
"/",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
));
|
||||
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")));
|
||||
|
||||
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||
|
@ -627,8 +505,10 @@ fn reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
|||
}
|
||||
|
||||
fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::ClientBuilder> {
|
||||
// for security reasons (e.g. malicious open redirect), we do not want to follow too many redirects when generating URL previews.
|
||||
// let's keep it at least 2 to account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider raising it to 3.
|
||||
// for security reasons (e.g. malicious open redirect), we do not want to follow
|
||||
// too many redirects when generating URL previews. let's keep it at least 2 to
|
||||
// account for HTTP -> HTTPS upgrades, if it becomes an issue we can consider
|
||||
// raising it to 3.
|
||||
let redirect_policy = reqwest::redirect::Policy::custom(|attempt| {
|
||||
if attempt.previous().len() > 2 {
|
||||
attempt.error("Too many redirects (max is 2)")
|
||||
|
@ -642,11 +522,7 @@ fn url_preview_reqwest_client_builder(config: &Config) -> Result<reqwest::Client
|
|||
.connect_timeout(Duration::from_secs(60))
|
||||
.timeout(Duration::from_secs(60 * 5))
|
||||
.redirect(redirect_policy)
|
||||
.user_agent(concat!(
|
||||
env!("CARGO_PKG_NAME"),
|
||||
"/",
|
||||
env!("CARGO_PKG_VERSION")
|
||||
));
|
||||
.user_agent(concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION")));
|
||||
|
||||
if let Some(proxy) = config.proxy.to_proxy()? {
|
||||
reqwest_client_builder = reqwest_client_builder.proxy(proxy);
|
||||
|
|
|
@ -1,78 +1,47 @@
|
|||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::Result;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||
|
||||
fn delete_backup(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String>;
|
||||
fn update_backup(&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String>;
|
||||
|
||||
fn get_latest_backup_version(&self, user_id: &UserId) -> Result<Option<String>>;
|
||||
|
||||
fn get_latest_backup(&self, user_id: &UserId)
|
||||
-> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||
fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>>;
|
||||
|
||||
fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>>;
|
||||
|
||||
fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()>;
|
||||
|
||||
fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize>;
|
||||
|
||||
fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String>;
|
||||
|
||||
fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
||||
fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>>;
|
||||
|
||||
fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>>;
|
||||
|
||||
fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>>;
|
||||
|
||||
fn delete_all_keys(&self, user_id: &UserId, version: &str) -> Result<()>;
|
||||
|
||||
fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()>;
|
||||
|
||||
fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()>;
|
||||
fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()>;
|
||||
}
|
||||
|
|
|
@ -1,24 +1,21 @@
|
|||
mod data;
|
||||
pub(crate) use data::Data;
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::Result;
|
||||
pub(crate) use data::Data;
|
||||
use ruma::{
|
||||
api::client::backup::{BackupAlgorithm, KeyBackupData, RoomKeyBackup},
|
||||
serde::Raw,
|
||||
OwnedRoomId, RoomId, UserId,
|
||||
};
|
||||
use std::collections::BTreeMap;
|
||||
|
||||
use crate::Result;
|
||||
|
||||
pub struct Service {
|
||||
pub db: &'static dyn Data,
|
||||
}
|
||||
|
||||
impl Service {
|
||||
pub fn create_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
pub fn create_backup(&self, user_id: &UserId, backup_metadata: &Raw<BackupAlgorithm>) -> Result<String> {
|
||||
self.db.create_backup(user_id, backup_metadata)
|
||||
}
|
||||
|
||||
|
@ -27,10 +24,7 @@ impl Service {
|
|||
}
|
||||
|
||||
pub fn update_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
backup_metadata: &Raw<BackupAlgorithm>,
|
||||
&self, user_id: &UserId, version: &str, backup_metadata: &Raw<BackupAlgorithm>,
|
||||
) -> Result<String> {
|
||||
self.db.update_backup(user_id, version, backup_metadata)
|
||||
}
|
||||
|
@ -39,64 +33,36 @@ impl Service {
|
|||
self.db.get_latest_backup_version(user_id)
|
||||
}
|
||||
|
||||
pub fn get_latest_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
pub fn get_latest_backup(&self, user_id: &UserId) -> Result<Option<(String, Raw<BackupAlgorithm>)>> {
|
||||
self.db.get_latest_backup(user_id)
|
||||
}
|
||||
|
||||
pub fn get_backup(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
pub fn get_backup(&self, user_id: &UserId, version: &str) -> Result<Option<Raw<BackupAlgorithm>>> {
|
||||
self.db.get_backup(user_id, version)
|
||||
}
|
||||
|
||||
pub fn add_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
key_data: &Raw<KeyBackupData>,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str, key_data: &Raw<KeyBackupData>,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.add_key(user_id, version, room_id, session_id, key_data)
|
||||
self.db.add_key(user_id, version, room_id, session_id, key_data)
|
||||
}
|
||||
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> {
|
||||
self.db.count_keys(user_id, version)
|
||||
}
|
||||
pub fn count_keys(&self, user_id: &UserId, version: &str) -> Result<usize> { self.db.count_keys(user_id, version) }
|
||||
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> {
|
||||
self.db.get_etag(user_id, version)
|
||||
}
|
||||
pub fn get_etag(&self, user_id: &UserId, version: &str) -> Result<String> { self.db.get_etag(user_id, version) }
|
||||
|
||||
pub fn get_all(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
pub fn get_all(&self, user_id: &UserId, version: &str) -> Result<BTreeMap<OwnedRoomId, RoomKeyBackup>> {
|
||||
self.db.get_all(user_id, version)
|
||||
}
|
||||
|
||||
pub fn get_room(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId,
|
||||
) -> Result<BTreeMap<String, Raw<KeyBackupData>>> {
|
||||
self.db.get_room(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn get_session(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str,
|
||||
) -> Result<Option<Raw<KeyBackupData>>> {
|
||||
self.db.get_session(user_id, version, room_id, session_id)
|
||||
}
|
||||
|
@ -105,23 +71,11 @@ impl Service {
|
|||
self.db.delete_all_keys(user_id, version)
|
||||
}
|
||||
|
||||
pub fn delete_room_keys(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
) -> Result<()> {
|
||||
pub fn delete_room_keys(&self, user_id: &UserId, version: &str, room_id: &RoomId) -> Result<()> {
|
||||
self.db.delete_room_keys(user_id, version, room_id)
|
||||
}
|
||||
|
||||
pub fn delete_room_key(
|
||||
&self,
|
||||
user_id: &UserId,
|
||||
version: &str,
|
||||
room_id: &RoomId,
|
||||
session_id: &str,
|
||||
) -> Result<()> {
|
||||
self.db
|
||||
.delete_room_key(user_id, version, room_id, session_id)
|
||||
pub fn delete_room_key(&self, user_id: &UserId, version: &str, room_id: &RoomId, session_id: &str) -> Result<()> {
|
||||
self.db.delete_room_key(user_id, version, room_id, session_id)
|
||||
}
|
||||
}
|
||||
|
|
|
@ -2,22 +2,14 @@ use crate::Result;
|
|||
|
||||
pub trait Data: Send + Sync {
|
||||
fn create_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
|
||||
) -> Result<Vec<u8>>;
|
||||
|
||||
fn delete_file_mxc(&self, mxc: String) -> Result<()>;
|
||||
|
||||
/// Returns content_disposition, content_type and the metadata key.
|
||||
fn search_file_metadata(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
&self, mxc: String, width: u32, height: u32,
|
||||
) -> Result<(Option<String>, Option<String>, Vec<u8>)>;
|
||||
|
||||
fn search_mxc_metadata_prefix(&self, mxc: String) -> Result<Vec<Vec<u8>>>;
|
||||
|
@ -26,12 +18,7 @@ pub trait Data: Send + Sync {
|
|||
|
||||
fn remove_url_preview(&self, url: &str) -> Result<()>;
|
||||
|
||||
fn set_url_preview(
|
||||
&self,
|
||||
url: &str,
|
||||
data: &super::UrlPreviewData,
|
||||
timestamp: std::time::Duration,
|
||||
) -> Result<()>;
|
||||
fn set_url_preview(&self, url: &str, data: &super::UrlPreviewData, timestamp: std::time::Duration) -> Result<()>;
|
||||
|
||||
fn get_url_preview(&self, url: &str) -> Option<super::UrlPreviewData>;
|
||||
}
|
||||
|
|
|
@ -7,18 +7,17 @@ use std::{
|
|||
};
|
||||
|
||||
pub(crate) use data::Data;
|
||||
use image::imageops::FilterType;
|
||||
use ruma::OwnedMxcUri;
|
||||
use serde::Serialize;
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{services, utils, Error, Result};
|
||||
use image::imageops::FilterType;
|
||||
|
||||
use tokio::{
|
||||
fs::{self, File},
|
||||
io::{AsyncReadExt, AsyncWriteExt, BufReader},
|
||||
sync::Mutex,
|
||||
};
|
||||
use tracing::{debug, error};
|
||||
|
||||
use crate::{services, utils, Error, Result};
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct FileMeta {
|
||||
|
@ -29,35 +28,17 @@ pub struct FileMeta {
|
|||
|
||||
#[derive(Serialize, Default)]
|
||||
pub struct UrlPreviewData {
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:title")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:title"))]
|
||||
pub title: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:description")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:description"))]
|
||||
pub description: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image"))]
|
||||
pub image: Option<String>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "matrix:image:size")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "matrix:image:size"))]
|
||||
pub image_size: Option<usize>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image:width")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:width"))]
|
||||
pub image_width: Option<u32>,
|
||||
#[serde(
|
||||
skip_serializing_if = "Option::is_none",
|
||||
rename(serialize = "og:image:height")
|
||||
)]
|
||||
#[serde(skip_serializing_if = "Option::is_none", rename(serialize = "og:image:height"))]
|
||||
pub image_height: Option<u32>,
|
||||
}
|
||||
|
||||
|
@ -69,16 +50,10 @@ pub struct Service {
|
|||
impl Service {
|
||||
/// Uploads a file.
|
||||
pub async fn create(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
file: &[u8],
|
||||
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, file: &[u8],
|
||||
) -> Result<()> {
|
||||
// Width, Height = 0 if it's not a thumbnail
|
||||
let key = self
|
||||
.db
|
||||
.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||
let key = self.db.create_file_metadata(mxc, 0, 0, content_disposition, content_type)?;
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
|
@ -104,10 +79,7 @@ impl Service {
|
|||
};
|
||||
debug!("Got local file path: {:?}", file_path);
|
||||
|
||||
debug!(
|
||||
"Deleting local file {:?} from filesystem, original MXC: {}",
|
||||
file_path, mxc
|
||||
);
|
||||
debug!("Deleting local file {:?} from filesystem, original MXC: {}", file_path, mxc);
|
||||
tokio::fs::remove_file(file_path).await?;
|
||||
|
||||
debug!("Deleting MXC {mxc} from database");
|
||||
|
@ -117,23 +89,18 @@ impl Service {
|
|||
Ok(())
|
||||
} else {
|
||||
error!("Failed to find any media keys for MXC \"{mxc}\" in our database (MXC does not exist)");
|
||||
Err(Error::bad_database("Failed to find any media keys for the provided MXC in our database (MXC does not exist)"))
|
||||
Err(Error::bad_database(
|
||||
"Failed to find any media keys for the provided MXC in our database (MXC does not exist)",
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
/// Uploads or replaces a file thumbnail.
|
||||
pub async fn upload_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
content_disposition: Option<&str>,
|
||||
content_type: Option<&str>,
|
||||
width: u32,
|
||||
height: u32,
|
||||
&self, mxc: String, content_disposition: Option<&str>, content_type: Option<&str>, width: u32, height: u32,
|
||||
file: &[u8],
|
||||
) -> Result<()> {
|
||||
let key =
|
||||
self.db
|
||||
.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||
let key = self.db.create_file_metadata(mxc, width, height, content_disposition, content_type)?;
|
||||
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
|
@ -150,9 +117,7 @@ impl Service {
|
|||
|
||||
/// Downloads a file.
|
||||
pub async fn get(&self, mxc: String) -> Result<Option<FileMeta>> {
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc, 0, 0)
|
||||
{
|
||||
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc, 0, 0) {
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
} else {
|
||||
|
@ -161,9 +126,7 @@ impl Service {
|
|||
};
|
||||
|
||||
let mut file = Vec::new();
|
||||
BufReader::new(File::open(path).await?)
|
||||
.read_to_end(&mut file)
|
||||
.await?;
|
||||
BufReader::new(File::open(path).await?).read_to_end(&mut file).await?;
|
||||
|
||||
Ok(Some(FileMeta {
|
||||
content_disposition,
|
||||
|
@ -175,8 +138,8 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
/// Deletes all remote only media files in the given at or after time/duration. Returns a u32
|
||||
/// with the amount of media files deleted.
|
||||
/// Deletes all remote only media files in the given at or after
|
||||
/// time/duration. Returns a u32 with the amount of media files deleted.
|
||||
pub async fn delete_all_remote_media_at_after_time(&self, time: String) -> Result<u32> {
|
||||
if let Ok(all_keys) = self.db.get_all_media_keys() {
|
||||
let user_duration: SystemTime = match cyborgtime::parse_duration(&time) {
|
||||
|
@ -184,13 +147,11 @@ impl Service {
|
|||
debug!("Parsed duration: {:?}", duration);
|
||||
debug!("System time now: {:?}", SystemTime::now());
|
||||
SystemTime::now() - duration
|
||||
}
|
||||
},
|
||||
Err(e) => {
|
||||
error!("Failed to parse user-specified time duration: {}", e);
|
||||
return Err(Error::bad_database(
|
||||
"Failed to parse user-specified time duration.",
|
||||
));
|
||||
}
|
||||
return Err(Error::bad_database("Failed to parse user-specified time duration."));
|
||||
},
|
||||
};
|
||||
|
||||
let mut remote_mxcs: Vec<String> = vec![];
|
||||
|
@ -198,17 +159,16 @@ impl Service {
|
|||
for key in all_keys {
|
||||
debug!("Full MXC key from database: {:?}", key);
|
||||
|
||||
// we need to get the MXC URL from the first part of the key (the first 0xff / 255 push)
|
||||
// this code does look kinda crazy but blame conduit for using magic keys
|
||||
let mut parts = key.split(|&b| b == 0xff);
|
||||
// we need to get the MXC URL from the first part of the key (the first 0xff /
|
||||
// 255 push) this code does look kinda crazy but blame conduit for using magic
|
||||
// keys
|
||||
let mut parts = key.split(|&b| b == 0xFF);
|
||||
let mxc = parts
|
||||
.next()
|
||||
.map(|bytes| {
|
||||
utils::string_from_bytes(bytes).map_err(|e| {
|
||||
error!("Failed to parse MXC unicode bytes from our database: {}", e);
|
||||
Error::bad_database(
|
||||
"Failed to parse MXC unicode bytes from our database",
|
||||
)
|
||||
Error::bad_database("Failed to parse MXC unicode bytes from our database")
|
||||
})
|
||||
})
|
||||
.transpose()?;
|
||||
|
@ -219,7 +179,7 @@ impl Service {
|
|||
return Err(Error::bad_database(
|
||||
"Parsed MXC URL unicode bytes from database but still is None",
|
||||
));
|
||||
}
|
||||
},
|
||||
};
|
||||
|
||||
debug!("Parsed MXC key to URL: {}", mxc_s);
|
||||
|
@ -252,12 +212,13 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
debug!("Finished going through all our media in database for eligible keys to delete, checking if these are empty");
|
||||
debug!(
|
||||
"Finished going through all our media in database for eligible keys to delete, checking if these are \
|
||||
empty"
|
||||
);
|
||||
|
||||
if remote_mxcs.is_empty() {
|
||||
return Err(Error::bad_database(
|
||||
"Did not found any eligible MXCs to delete.",
|
||||
));
|
||||
return Err(Error::bad_database("Did not found any eligible MXCs to delete."));
|
||||
}
|
||||
|
||||
debug!("Deleting media now in the past \"{:?}\".", user_duration);
|
||||
|
@ -278,8 +239,8 @@ impl Service {
|
|||
}
|
||||
}
|
||||
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped. Returns None when
|
||||
/// the server should send the original file.
|
||||
/// Returns width, height of the thumbnail and whether it should be cropped.
|
||||
/// Returns None when the server should send the original file.
|
||||
pub fn thumbnail_properties(&self, width: u32, height: u32) -> Option<(u32, u32, bool)> {
|
||||
match (width, height) {
|
||||
(0..=32, 0..=32) => Some((32, 32, true)),
|
||||
|
@ -296,24 +257,18 @@ impl Service {
|
|||
/// Here's an example on how it works:
|
||||
///
|
||||
/// - Client requests an image with width=567, height=567
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio (only for width,height>96)
|
||||
/// - Server rounds that up to (800, 600), so it doesn't have to save too
|
||||
/// many thumbnails
|
||||
/// - Server rounds that up again to (958, 600) to fix the aspect ratio
|
||||
/// (only for width,height>96)
|
||||
/// - Server creates the thumbnail and sends it to the user
|
||||
///
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm which crops the image afterwards.
|
||||
pub async fn get_thumbnail(
|
||||
&self,
|
||||
mxc: String,
|
||||
width: u32,
|
||||
height: u32,
|
||||
) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self
|
||||
.thumbnail_properties(width, height)
|
||||
.unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
/// For width,height <= 96 the server uses another thumbnailing algorithm
|
||||
/// which crops the image afterwards.
|
||||
pub async fn get_thumbnail(&self, mxc: String, width: u32, height: u32) -> Result<Option<FileMeta>> {
|
||||
let (width, height, crop) = self.thumbnail_properties(width, height).unwrap_or((0, 0, false)); // 0, 0 because that's the original file
|
||||
|
||||
if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), width, height)
|
||||
{
|
||||
if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), width, height) {
|
||||
// Using saved thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
|
@ -330,9 +285,7 @@ impl Service {
|
|||
content_type,
|
||||
file: file.clone(),
|
||||
}))
|
||||
} else if let Ok((content_disposition, content_type, key)) =
|
||||
self.db.search_file_metadata(mxc.clone(), 0, 0)
|
||||
{
|
||||
} else if let Ok((content_disposition, content_type, key)) = self.db.search_file_metadata(mxc.clone(), 0, 0) {
|
||||
// Generate a thumbnail
|
||||
let path = if cfg!(feature = "sha256_media") {
|
||||
services().globals.get_media_file_new(&key)
|
||||
|
@@ -365,19 +318,16 @@ impl Service {

				let use_width = nratio <= ratio;
				let intermediate = if use_width {
					u64::from(original_height) * u64::from(width)
						/ u64::from(original_width)
					u64::from(original_height) * u64::from(width) / u64::from(original_width)
				} else {
					u64::from(original_width) * u64::from(height)
						/ u64::from(original_height)
					u64::from(original_width) * u64::from(height) / u64::from(original_height)
				};
				if use_width {
					if intermediate <= u64::from(::std::u32::MAX) {
						(width, intermediate as u32)
					} else {
						(
							(u64::from(width) * u64::from(::std::u32::MAX) / intermediate)
								as u32,
							(u64::from(width) * u64::from(::std::u32::MAX) / intermediate) as u32,
							::std::u32::MAX,
						)
					}

@@ -386,8 +336,7 @@ impl Service {
				} else {
					(
						::std::u32::MAX,
						(u64::from(height) * u64::from(::std::u32::MAX) / intermediate)
							as u32,
						(u64::from(height) * u64::from(::std::u32::MAX) / intermediate) as u32,
					)
				}
			};
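The two hunks above only reformat the "fit inside the target box while preserving the original aspect ratio" arithmetic, clamped to u32::MAX. As a standalone sketch (not part of this commit; the function name is made up, and the ratio comparison is reconstructed since the diff only shows that `use_width = nratio <= ratio` picks the limiting side):

// Illustrative sketch of the fitting arithmetic shown in the hunks above.
fn fit_within(original_width: u32, original_height: u32, width: u32, height: u32) -> (u32, u32) {
    // If the target box is relatively wider than the original, width is the limiting side.
    let use_width =
        u64::from(width) * u64::from(original_height) <= u64::from(height) * u64::from(original_width);
    let intermediate = if use_width {
        u64::from(original_height) * u64::from(width) / u64::from(original_width)
    } else {
        u64::from(original_width) * u64::from(height) / u64::from(original_height)
    };
    if use_width {
        if intermediate <= u64::from(u32::MAX) {
            (width, intermediate as u32)
        } else {
            ((u64::from(width) * u64::from(u32::MAX) / intermediate) as u32, u32::MAX)
        }
    } else if intermediate <= u64::from(u32::MAX) {
        (intermediate as u32, height)
    } else {
        (u32::MAX, (u64::from(height) * u64::from(u32::MAX) / intermediate) as u32)
    }
}

fn main() {
    // A 1916x1200 original fitted into an 800x600 box: width is the limiting side,
    // so the height scales down to 1200 * 800 / 1916 = 501.
    assert_eq!(fit_within(1916, 1200, 800, 600), (800, 501));
}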
@@ -396,10 +345,7 @@ impl Service {
			};

			let mut thumbnail_bytes = Vec::new();
			thumbnail.write_to(
				&mut Cursor::new(&mut thumbnail_bytes),
				image::ImageOutputFormat::Png,
			)?;
			thumbnail.write_to(&mut Cursor::new(&mut thumbnail_bytes), image::ImageOutputFormat::Png)?;

			// Save thumbnail in database so we don't have to generate it again next time
			let thumbnail_key = self.db.create_file_metadata(
@@ -438,9 +384,7 @@ impl Service {
		}
	}

	pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> {
		self.db.get_url_preview(url)
	}
	pub async fn get_url_preview(&self, url: &str) -> Option<UrlPreviewData> { self.db.get_url_preview(url) }

	pub async fn remove_url_preview(&self, url: &str) -> Result<()> {
		// TODO: also remove the downloaded image

@@ -448,9 +392,7 @@ impl Service {
	}

	pub async fn set_url_preview(&self, url: &str, data: &UrlPreviewData) -> Result<()> {
		let now = SystemTime::now()
			.duration_since(SystemTime::UNIX_EPOCH)
			.expect("valid system time");
		let now = SystemTime::now().duration_since(SystemTime::UNIX_EPOCH).expect("valid system time");
		self.db.set_url_preview(url, data, now)
	}
}
@@ -459,9 +401,8 @@ impl Service {
mod tests {
	use std::path::PathBuf;

	use sha2::Digest;

	use base64::{engine::general_purpose, Engine as _};
	use sha2::Digest;

	use super::*;
@@ -469,73 +410,40 @@ mod tests {

	impl Data for MockedKVDatabase {
		fn create_file_metadata(
			&self,
			mxc: String,
			width: u32,
			height: u32,
			content_disposition: Option<&str>,
			content_type: Option<&str>,
			&self, mxc: String, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>,
		) -> Result<Vec<u8>> {
			// copied from src/database/key_value/media.rs
			let mut key = mxc.as_bytes().to_vec();
			key.push(0xff);
			key.push(0xFF);
			key.extend_from_slice(&width.to_be_bytes());
			key.extend_from_slice(&height.to_be_bytes());
			key.push(0xff);
			key.extend_from_slice(
				content_disposition
					.as_ref()
					.map(|f| f.as_bytes())
					.unwrap_or_default(),
			);
			key.push(0xff);
			key.extend_from_slice(
				content_type
					.as_ref()
					.map(|c| c.as_bytes())
					.unwrap_or_default(),
			);
			key.push(0xFF);
			key.extend_from_slice(content_disposition.as_ref().map(|f| f.as_bytes()).unwrap_or_default());
			key.push(0xFF);
			key.extend_from_slice(content_type.as_ref().map(|c| c.as_bytes()).unwrap_or_default());

			Ok(key)
		}

		fn delete_file_mxc(&self, _mxc: String) -> Result<()> {
			todo!()
		}
		fn delete_file_mxc(&self, _mxc: String) -> Result<()> { todo!() }

		fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> {
			todo!()
		}
		fn search_mxc_metadata_prefix(&self, _mxc: String) -> Result<Vec<Vec<u8>>> { todo!() }

		fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> {
			todo!()
		}
		fn get_all_media_keys(&self) -> Result<Vec<Vec<u8>>> { todo!() }

		fn search_file_metadata(
			&self,
			_mxc: String,
			_width: u32,
			_height: u32,
			&self, _mxc: String, _width: u32, _height: u32,
		) -> Result<(Option<String>, Option<String>, Vec<u8>)> {
			todo!()
		}

		fn remove_url_preview(&self, _url: &str) -> Result<()> {
		fn remove_url_preview(&self, _url: &str) -> Result<()> { todo!() }

		fn set_url_preview(&self, _url: &str, _data: &UrlPreviewData, _timestamp: std::time::Duration) -> Result<()> {
			todo!()
		}

		fn set_url_preview(
			&self,
			_url: &str,
			_data: &UrlPreviewData,
			_timestamp: std::time::Duration,
		) -> Result<()> {
			todo!()
		}

		fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> {
			todo!()
		}
		fn get_url_preview(&self, _url: &str) -> Option<UrlPreviewData> { todo!() }
	}

	#[tokio::test]

@@ -549,18 +457,11 @@ mod tests {
		let mxc = "mxc://example.com/ascERGshawAWawugaAcauga".to_owned();
		let width = 100;
		let height = 100;
		let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special characters like äöüß and even emoji like 🦀.png\"";
		let content_disposition = "attachment; filename=\"this is a very long file name with spaces and special \
		                           characters like äöüß and even emoji like 🦀.png\"";
		let content_type = "image/png";
		let key = media
			.db
			.create_file_metadata(
				mxc,
				width,
				height,
				Some(content_disposition),
				Some(content_type),
			)
			.unwrap();
		let key =
			media.db.create_file_metadata(mxc, width, height, Some(content_disposition), Some(content_type)).unwrap();
		let mut r = PathBuf::new();
		r.push("/tmp");
		r.push("media");
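The mocked create_file_metadata in the tests above encodes the media metadata key as: mxc bytes, 0xFF, big-endian width, big-endian height, 0xFF, content disposition, 0xFF, content type. A standalone sketch of that layout with a worked example (not part of this commit; the helper name is made up):

// Illustrative sketch of the key layout built by create_file_metadata above.
fn media_key(mxc: &str, width: u32, height: u32, content_disposition: Option<&str>, content_type: Option<&str>) -> Vec<u8> {
    let mut key = mxc.as_bytes().to_vec();
    key.push(0xFF);
    key.extend_from_slice(&width.to_be_bytes());
    key.extend_from_slice(&height.to_be_bytes());
    key.push(0xFF);
    key.extend_from_slice(content_disposition.map(str::as_bytes).unwrap_or_default());
    key.push(0xFF);
    key.extend_from_slice(content_type.map(str::as_bytes).unwrap_or_default());
    key
}

fn main() {
    let key = media_key("mxc://example.com/abc", 32, 32, None, Some("image/png"));
    // "mxc://example.com/abc" is 21 bytes, then 0xFF, 4 + 4 bytes of big-endian
    // dimensions, then two more 0xFF separators and "image/png" (9 bytes).
    assert_eq!(key.len(), 21 + 1 + 4 + 4 + 1 + 0 + 1 + 9);
    assert_eq!(&key[22..26], &32u32.to_be_bytes());
}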
@@ -51,32 +51,59 @@ impl Services<'_> {
			+ sending::Data
			+ 'static,
	>(
		db: &'static D,
		config: Config,
		db: &'static D, config: Config,
	) -> Result<Self> {
		Ok(Self {
			appservice: appservice::Service { db },
			pusher: pusher::Service { db },
			appservice: appservice::Service {
				db,
			},
			pusher: pusher::Service {
				db,
			},
			rooms: rooms::Service {
				alias: rooms::alias::Service { db },
				auth_chain: rooms::auth_chain::Service { db },
				directory: rooms::directory::Service { db },
				alias: rooms::alias::Service {
					db,
				},
				auth_chain: rooms::auth_chain::Service {
					db,
				},
				directory: rooms::directory::Service {
					db,
				},
				edus: rooms::edus::Service {
					presence: rooms::edus::presence::Service { db },
					read_receipt: rooms::edus::read_receipt::Service { db },
					typing: rooms::edus::typing::Service { db },
					presence: rooms::edus::presence::Service {
						db,
					},
					read_receipt: rooms::edus::read_receipt::Service {
						db,
					},
					typing: rooms::edus::typing::Service {
						db,
					},
				},
				event_handler: rooms::event_handler::Service,
				lazy_loading: rooms::lazy_loading::Service {
					db,
					lazy_load_waiting: Mutex::new(HashMap::new()),
				},
				metadata: rooms::metadata::Service { db },
				outlier: rooms::outlier::Service { db },
				pdu_metadata: rooms::pdu_metadata::Service { db },
				search: rooms::search::Service { db },
				short: rooms::short::Service { db },
				state: rooms::state::Service { db },
				metadata: rooms::metadata::Service {
					db,
				},
				outlier: rooms::outlier::Service {
					db,
				},
				pdu_metadata: rooms::pdu_metadata::Service {
					db,
				},
				search: rooms::search::Service {
					db,
				},
				short: rooms::short::Service {
					db,
				},
				state: rooms::state::Service {
					db,
				},
				state_accessor: rooms::state_accessor::Service {
					db,
					server_visibility_cache: Mutex::new(LruCache::new(

@@ -86,7 +113,9 @@ impl Services<'_> {
					(100.0 * config.conduit_cache_capacity_modifier) as usize,
					)),
				},
				state_cache: rooms::state_cache::Service { db },
				state_cache: rooms::state_cache::Service {
					db,
				},
				state_compressor: rooms::state_compressor::Service {
					db,
					stateinfo_cache: Mutex::new(LruCache::new(

@@ -97,23 +126,35 @@ impl Services<'_> {
					db,
					lasttimelinecount_cache: Mutex::new(HashMap::new()),
				},
				threads: rooms::threads::Service { db },
				threads: rooms::threads::Service {
					db,
				},
				spaces: rooms::spaces::Service {
					roomid_spacechunk_cache: Mutex::new(LruCache::new(
						(100.0 * config.conduit_cache_capacity_modifier) as usize,
					)),
				},
				user: rooms::user::Service { db },
				user: rooms::user::Service {
					db,
				},
			},
			transaction_ids: transaction_ids::Service {
				db,
			},
			uiaa: uiaa::Service {
				db,
			},
			transaction_ids: transaction_ids::Service { db },
			uiaa: uiaa::Service { db },
			users: users::Service {
				db,
				connections: Mutex::new(BTreeMap::new()),
			},
			account_data: account_data::Service { db },
			account_data: account_data::Service {
				db,
			},
			admin: admin::Service::build(),
			key_backups: key_backups::Service { db },
			key_backups: key_backups::Service {
				db,
			},
			media: media::Service {
				db,
				url_preview_mutex: RwLock::new(HashMap::new()),

@@ -123,49 +164,14 @@ impl Services<'_> {
			globals: globals::Service::load(db, config)?,
		})
	}

	fn memory_usage(&self) -> String {
		let lazy_load_waiting = self
			.rooms
			.lazy_loading
			.lazy_load_waiting
			.lock()
			.unwrap()
			.len();
		let server_visibility_cache = self
			.rooms
			.state_accessor
			.server_visibility_cache
			.lock()
			.unwrap()
			.len();
		let user_visibility_cache = self
			.rooms
			.state_accessor
			.user_visibility_cache
			.lock()
			.unwrap()
			.len();
		let stateinfo_cache = self
			.rooms
			.state_compressor
			.stateinfo_cache
			.lock()
			.unwrap()
			.len();
		let lasttimelinecount_cache = self
			.rooms
			.timeline
			.lasttimelinecount_cache
			.lock()
			.unwrap()
			.len();
		let roomid_spacechunk_cache = self
			.rooms
			.spaces
			.roomid_spacechunk_cache
			.lock()
			.unwrap()
			.len();
		let lazy_load_waiting = self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().len();
		let server_visibility_cache = self.rooms.state_accessor.server_visibility_cache.lock().unwrap().len();
		let user_visibility_cache = self.rooms.state_accessor.user_visibility_cache.lock().unwrap().len();
		let stateinfo_cache = self.rooms.state_compressor.stateinfo_cache.lock().unwrap().len();
		let lasttimelinecount_cache = self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().len();
		let roomid_spacechunk_cache = self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().len();

		format!(
			"\

@@ -174,58 +180,28 @@ server_visibility_cache: {server_visibility_cache}
user_visibility_cache: {user_visibility_cache}
stateinfo_cache: {stateinfo_cache}
lasttimelinecount_cache: {lasttimelinecount_cache}
roomid_spacechunk_cache: {roomid_spacechunk_cache}\
"
roomid_spacechunk_cache: {roomid_spacechunk_cache}"
		)
	}

	fn clear_caches(&self, amount: u32) {
		if amount > 0 {
			self.rooms
				.lazy_loading
				.lazy_load_waiting
				.lock()
				.unwrap()
				.clear();
			self.rooms.lazy_loading.lazy_load_waiting.lock().unwrap().clear();
		}
		if amount > 1 {
			self.rooms
				.state_accessor
				.server_visibility_cache
				.lock()
				.unwrap()
				.clear();
			self.rooms.state_accessor.server_visibility_cache.lock().unwrap().clear();
		}
		if amount > 2 {
			self.rooms
				.state_accessor
				.user_visibility_cache
				.lock()
				.unwrap()
				.clear();
			self.rooms.state_accessor.user_visibility_cache.lock().unwrap().clear();
		}
		if amount > 3 {
			self.rooms
				.state_compressor
				.stateinfo_cache
				.lock()
				.unwrap()
				.clear();
			self.rooms.state_compressor.stateinfo_cache.lock().unwrap().clear();
		}
		if amount > 4 {
			self.rooms
				.timeline
				.lasttimelinecount_cache
				.lock()
				.unwrap()
				.clear();
			self.rooms.timeline.lasttimelinecount_cache.lock().unwrap().clear();
		}
		if amount > 5 {
			self.rooms
				.spaces
				.roomid_spacechunk_cache
				.lock()
				.unwrap()
				.clear();
			self.rooms.spaces.roomid_spacechunk_cache.lock().unwrap().clear();
		}
	}
}
@@ -1,23 +1,25 @@
use crate::Error;
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};

use ruma::{
	canonical_json::redact_content_in_place,
	events::{
		room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent,
		AnyEphemeralRoomEvent, AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent,
		AnySyncStateEvent, AnySyncTimelineEvent, AnyTimelineEvent, StateEvent, TimelineEventType,
		room::member::RoomMemberEventContent, space::child::HierarchySpaceChildEvent, AnyEphemeralRoomEvent,
		AnyMessageLikeEvent, AnyStateEvent, AnyStrippedStateEvent, AnySyncStateEvent, AnySyncTimelineEvent,
		AnyTimelineEvent, StateEvent, TimelineEventType,
	},
	serde::Raw,
	state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch,
	OwnedEventId, OwnedRoomId, OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
	state_res, CanonicalJsonObject, CanonicalJsonValue, EventId, MilliSecondsSinceUnixEpoch, OwnedEventId, OwnedRoomId,
	OwnedUserId, RoomId, RoomVersionId, UInt, UserId,
};
use serde::{Deserialize, Serialize};
use serde_json::{
	json,
	value::{to_raw_value, RawValue as RawJsonValue},
};
use std::{cmp::Ordering, collections::BTreeMap, sync::Arc};
use tracing::warn;

use crate::Error;

/// Content hashes of a PDU.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EventHash {

@@ -50,11 +52,7 @@ pub struct PduEvent {

impl PduEvent {
	#[tracing::instrument(skip(self))]
	pub fn redact(
		&mut self,
		room_version_id: RoomVersionId,
		reason: &PduEvent,
	) -> crate::Result<()> {
	pub fn redact(&mut self, room_version_id: RoomVersionId, reason: &PduEvent) -> crate::Result<()> {
		self.unsigned = None;

		let mut content = serde_json::from_str(self.content.get())

@@ -62,9 +60,12 @@ impl PduEvent {
		redact_content_in_place(&mut content, &room_version_id, self.kind.to_string())
			.map_err(|e| Error::RedactionError(self.sender.server_name().to_owned(), e))?;

		self.unsigned = Some(to_raw_value(&json!({
		self.unsigned = Some(
			to_raw_value(&json!({
				"redacted_because": serde_json::to_value(reason).expect("to_value(PduEvent) always works")
			})).expect("to string always works"));
			}))
			.expect("to string always works"),
		);

		self.content = to_raw_value(&content).expect("to string always works");

@@ -73,8 +74,7 @@ impl PduEvent {

	pub fn remove_transaction_id(&mut self) -> crate::Result<()> {
		if let Some(unsigned) = &self.unsigned {
			let mut unsigned: BTreeMap<String, Box<RawJsonValue>> =
				serde_json::from_str(unsigned.get())
			let mut unsigned: BTreeMap<String, Box<RawJsonValue>> = serde_json::from_str(unsigned.get())
				.map_err(|_| Error::bad_database("Invalid unsigned in pdu event"))?;
			unsigned.remove("transaction_id");
			self.unsigned = Some(to_raw_value(&unsigned).expect("unsigned is valid"));

@@ -276,36 +276,25 @@ impl PduEvent {

	/// This does not return a full `Pdu` it is only to satisfy ruma's types.
	#[tracing::instrument]
	pub fn convert_to_outgoing_federation_event(
		mut pdu_json: CanonicalJsonObject,
	) -> Box<RawJsonValue> {
		if let Some(unsigned) = pdu_json
			.get_mut("unsigned")
			.and_then(|val| val.as_object_mut())
		{
	pub fn convert_to_outgoing_federation_event(mut pdu_json: CanonicalJsonObject) -> Box<RawJsonValue> {
		if let Some(unsigned) = pdu_json.get_mut("unsigned").and_then(|val| val.as_object_mut()) {
			unsigned.remove("transaction_id");
		}

		pdu_json.remove("event_id");

		// TODO: another option would be to convert it to a canonical string to validate size
		// and return a Result<Raw<...>>
		// TODO: another option would be to convert it to a canonical string to validate
		// size and return a Result<Raw<...>>
		// serde_json::from_str::<Raw<_>>(
		//     ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is valid serde_json::Value"),
		// )
		//     ruma::serde::to_canonical_json_string(pdu_json).expect("CanonicalJson is
		// valid serde_json::Value"), )
		// .expect("Raw::from_value always works")

		to_raw_value(&pdu_json).expect("CanonicalJson is valid serde_json::Value")
	}

	pub fn from_id_val(
		event_id: &EventId,
		mut json: CanonicalJsonObject,
	) -> Result<Self, serde_json::Error> {
		json.insert(
			"event_id".to_owned(),
			CanonicalJsonValue::String(event_id.as_str().to_owned()),
		);
	pub fn from_id_val(event_id: &EventId, mut json: CanonicalJsonObject) -> Result<Self, serde_json::Error> {
		json.insert("event_id".to_owned(), CanonicalJsonValue::String(event_id.as_str().to_owned()));

		serde_json::from_value(serde_json::to_value(json).expect("valid JSON"))
	}

@@ -314,72 +303,46 @@ impl PduEvent {
impl state_res::Event for PduEvent {
	type Id = Arc<EventId>;

	fn event_id(&self) -> &Self::Id {
		&self.event_id
	}
	fn event_id(&self) -> &Self::Id { &self.event_id }

	fn room_id(&self) -> &RoomId {
		&self.room_id
	}
	fn room_id(&self) -> &RoomId { &self.room_id }

	fn sender(&self) -> &UserId {
		&self.sender
	}
	fn sender(&self) -> &UserId { &self.sender }

	fn event_type(&self) -> &TimelineEventType {
		&self.kind
	}
	fn event_type(&self) -> &TimelineEventType { &self.kind }

	fn content(&self) -> &RawJsonValue {
		&self.content
	}
	fn content(&self) -> &RawJsonValue { &self.content }

	fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch {
		MilliSecondsSinceUnixEpoch(self.origin_server_ts)
	}
	fn origin_server_ts(&self) -> MilliSecondsSinceUnixEpoch { MilliSecondsSinceUnixEpoch(self.origin_server_ts) }

	fn state_key(&self) -> Option<&str> {
		self.state_key.as_deref()
	}
	fn state_key(&self) -> Option<&str> { self.state_key.as_deref() }

	fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
		Box::new(self.prev_events.iter())
	}
	fn prev_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.prev_events.iter()) }

	fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> {
		Box::new(self.auth_events.iter())
	}
	fn auth_events(&self) -> Box<dyn DoubleEndedIterator<Item = &Self::Id> + '_> { Box::new(self.auth_events.iter()) }

	fn redacts(&self) -> Option<&Self::Id> {
		self.redacts.as_ref()
	}
	fn redacts(&self) -> Option<&Self::Id> { self.redacts.as_ref() }
}

// These impl's allow us to dedup state snapshots when resolving state
// for incoming events (federation/send/{txn}).
impl Eq for PduEvent {}
impl PartialEq for PduEvent {
	fn eq(&self, other: &Self) -> bool {
		self.event_id == other.event_id
	}
	fn eq(&self, other: &Self) -> bool { self.event_id == other.event_id }
}
impl PartialOrd for PduEvent {
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
		Some(self.cmp(other))
	}
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}
impl Ord for PduEvent {
	fn cmp(&self, other: &Self) -> Ordering {
		self.event_id.cmp(&other.event_id)
	}
	fn cmp(&self, other: &Self) -> Ordering { self.event_id.cmp(&other.event_id) }
}

/// Generates a correct eventId for the incoming pdu.
///
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String, CanonicalJsonValue>`.
/// Returns a tuple of the new `EventId` and the PDU as a `BTreeMap<String,
/// CanonicalJsonValue>`.
pub(crate) fn gen_event_id_canonical_json(
	pdu: &RawJsonValue,
	room_version_id: &RoomVersionId,
	pdu: &RawJsonValue, room_version_id: &RoomVersionId,
) -> crate::Result<(OwnedEventId, CanonicalJsonObject)> {
	let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
		warn!("Error parsing incoming event {:?}: {:?}", pdu, e);

@@ -389,8 +352,7 @@ pub(crate) fn gen_event_id_canonical_json(
	let event_id = format!(
		"${}",
		// Anything higher than version3 behaves the same
		ruma::signatures::reference_hash(&value, room_version_id)
			.expect("ruma can calculate reference hashes")
		ruma::signatures::reference_hash(&value, room_version_id).expect("ruma can calculate reference hashes")
	)
	.try_into()
	.expect("ruma's reference hashes are valid event ids");
@@ -1,9 +1,10 @@
use crate::Result;
use ruma::{
	api::client::push::{set_pusher, Pusher},
	UserId,
};

use crate::Result;

pub trait Data: Send + Sync {
	fn set_pusher(&self, sender: &UserId, pusher: set_pusher::v3::PusherAction) -> Result<()>;

@@ -11,6 +12,5 @@ pub trait Data: Send + Sync {

	fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>>;

	fn get_pushkeys<'a>(&'a self, sender: &UserId)
		-> Box<dyn Iterator<Item = Result<String>> + 'a>;
	fn get_pushkeys<'a>(&'a self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>> + 'a>;
}

@@ -1,9 +1,8 @@
mod data;
pub use data::Data;
use ruma::{events::AnySyncTimelineEvent, push::PushConditionPowerLevelsCtx};
use std::{fmt::Debug, mem};

use crate::{services, Error, PduEvent, Result};
use bytes::BytesMut;
pub use data::Data;
use ruma::{
	api::{
		client::push::{set_pusher, Pusher, PusherKind},

@@ -13,15 +12,17 @@ use ruma::{
		},
		IncomingResponse, MatrixVersion, OutgoingRequest, SendAccessToken,
	},
	events::{room::power_levels::RoomPowerLevelsEventContent, StateEventType, TimelineEventType},
	push::{Action, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
	events::{
		room::power_levels::RoomPowerLevelsEventContent, AnySyncTimelineEvent, StateEventType, TimelineEventType,
	},
	push::{Action, PushConditionPowerLevelsCtx, PushConditionRoomCtx, PushFormat, Ruleset, Tweak},
	serde::Raw,
	uint, RoomId, UInt, UserId,
};

use std::{fmt::Debug, mem};
use tracing::{info, warn};

use crate::{services, Error, PduEvent, Result};

pub struct Service {
	pub db: &'static dyn Data,
}

@@ -35,31 +36,21 @@ impl Service {
		self.db.get_pusher(sender, pushkey)
	}

	pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> {
		self.db.get_pushers(sender)
	}
	pub fn get_pushers(&self, sender: &UserId) -> Result<Vec<Pusher>> { self.db.get_pushers(sender) }

	pub fn get_pushkeys(&self, sender: &UserId) -> Box<dyn Iterator<Item = Result<String>>> {
		self.db.get_pushkeys(sender)
	}

	#[tracing::instrument(skip(self, destination, request))]
	pub async fn send_request<T>(
		&self,
		destination: &str,
		request: T,
	) -> Result<T::IncomingResponse>
	pub async fn send_request<T>(&self, destination: &str, request: T) -> Result<T::IncomingResponse>
	where
		T: OutgoingRequest + Debug,
	{
		let destination = destination.replace(services().globals.notification_push_path(), "");

		let http_request = request
			.try_into_http_request::<BytesMut>(
				&destination,
				SendAccessToken::IfRequired(""),
				&[MatrixVersion::V1_0],
			)
			.try_into_http_request::<BytesMut>(&destination, SendAccessToken::IfRequired(""), &[MatrixVersion::V1_0])
			.map_err(|e| {
				warn!("Failed to find destination {}: {}", destination, e);
				Error::BadServerResponse("Invalid destination")

@@ -72,24 +63,16 @@ impl Service {
		//*reqwest_request.timeout_mut() = Some(Duration::from_secs(5));

		let url = reqwest_request.url().clone();
		let response = services()
			.globals
			.default_client()
			.execute(reqwest_request)
			.await;
		let response = services().globals.default_client().execute(reqwest_request).await;

		match response {
			Ok(mut response) => {
				// reqwest::Response -> http::Response conversion
				let status = response.status();
				let mut http_response_builder = http::Response::builder()
					.status(status)
					.version(response.version());
				let mut http_response_builder = http::Response::builder().status(status).version(response.version());
				mem::swap(
					response.headers_mut(),
					http_response_builder
						.headers_mut()
						.expect("http::response::Builder is usable"),
					http_response_builder.headers_mut().expect("http::response::Builder is usable"),
				);

				let body = response.bytes().await.unwrap_or_else(|e| {

@@ -108,33 +91,23 @@ impl Service {
				}

				let response = T::IncomingResponse::try_from_http_response(
					http_response_builder
						.body(body)
						.expect("reqwest body is valid http body"),
					http_response_builder.body(body).expect("reqwest body is valid http body"),
				);
				response.map_err(|_| {
					info!(
						"Push gateway returned invalid response bytes {}\n{}",
						destination, url
					);
					info!("Push gateway returned invalid response bytes {}\n{}", destination, url);
					Error::BadServerResponse("Push gateway returned bad response.")
				})
			}
			},
			Err(e) => {
				warn!("Could not send request to pusher {}: {}", destination, e);
				Err(e.into())
			}
			},
		}
	}

	#[tracing::instrument(skip(self, user, unread, pusher, ruleset, pdu))]
	pub async fn send_push_notice(
		&self,
		user: &UserId,
		unread: UInt,
		pusher: &Pusher,
		ruleset: Ruleset,
		pdu: &PduEvent,
		&self, user: &UserId, unread: UInt, pusher: &Pusher, ruleset: Ruleset, pdu: &PduEvent,
	) -> Result<()> {
		let mut notify = None;
		let mut tweaks = Vec::new();

@@ -150,19 +123,13 @@ impl Service {
			.transpose()?
			.unwrap_or_default();

		for action in self.get_actions(
			user,
			&ruleset,
			&power_levels,
			&pdu.to_sync_room_event(),
			&pdu.room_id,
		)? {
		for action in self.get_actions(user, &ruleset, &power_levels, &pdu.to_sync_room_event(), &pdu.room_id)? {
			let n = match action {
				Action::Notify => true,
				Action::SetTweak(tweak) => {
					tweaks.push(tweak.clone());
					continue;
				}
				},
				_ => false,
			};

@@ -185,12 +152,8 @@ impl Service {

	#[tracing::instrument(skip(self, user, ruleset, pdu))]
	pub fn get_actions<'a>(
		&self,
		user: &UserId,
		ruleset: &'a Ruleset,
		power_levels: &RoomPowerLevelsEventContent,
		pdu: &Raw<AnySyncTimelineEvent>,
		room_id: &RoomId,
		&self, user: &UserId, ruleset: &'a Ruleset, power_levels: &RoomPowerLevelsEventContent,
		pdu: &Raw<AnySyncTimelineEvent>, room_id: &RoomId,
	) -> Result<&'a [Action]> {
		let power_levels = PushConditionPowerLevelsCtx {
			users: power_levels.users.clone(),

@@ -200,18 +163,9 @@ impl Service {

		let ctx = PushConditionRoomCtx {
			room_id: room_id.to_owned(),
			member_count: UInt::from(
				services()
					.rooms
					.state_cache
					.room_joined_count(room_id)?
					.unwrap_or(1) as u32,
			),
			member_count: UInt::from(services().rooms.state_cache.room_joined_count(room_id)?.unwrap_or(1) as u32),
			user_id: user.to_owned(),
			user_display_name: services()
				.users
				.displayname(user)?
				.unwrap_or_else(|| user.localpart().to_owned()),
			user_display_name: services().users.displayname(user)?.unwrap_or_else(|| user.localpart().to_owned()),
			power_levels: Some(power_levels),
		};

@@ -219,19 +173,14 @@ impl Service {
	}

	#[tracing::instrument(skip(self, unread, pusher, tweaks, event))]
	async fn send_notice(
		&self,
		unread: UInt,
		pusher: &Pusher,
		tweaks: Vec<Tweak>,
		event: &PduEvent,
	) -> Result<()> {
	async fn send_notice(&self, unread: UInt, pusher: &Pusher, tweaks: Vec<Tweak>, event: &PduEvent) -> Result<()> {
		// TODO: email
		match &pusher.kind {
			PusherKind::Http(http) => {
				// TODO:
				// Two problems with this
				// 1. if "event_id_only" is the only format kind it seems we should never add more info
				// 1. if "event_id_only" is the only format kind it seems we should never add
				//    more info
				// 2. can pusher/devices have conflicting formats
				let event_id_only = http.format == Some(PushFormat::EventIdOnly);

@@ -254,36 +203,31 @@ impl Service {
				notifi.counts = NotificationCounts::new(unread, uint!(0));

				if event.kind == TimelineEventType::RoomEncrypted
					|| tweaks
						.iter()
						.any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
					|| tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_)))
				{
					notifi.prio = NotificationPriority::High;
				}

				if event_id_only {
					self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
						.await?;
					self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
				} else {
					notifi.sender = Some(event.sender.clone());
					notifi.event_type = Some(event.kind.clone());
					notifi.content = serde_json::value::to_raw_value(&event.content).ok();

					if event.kind == TimelineEventType::RoomMember {
						notifi.user_is_target =
							event.state_key.as_deref() == Some(event.sender.as_str());
						notifi.user_is_target = event.state_key.as_deref() == Some(event.sender.as_str());
					}

					notifi.sender_display_name = services().users.displayname(&event.sender)?;

					notifi.room_name = services().rooms.state_accessor.get_name(&event.room_id)?;

					self.send_request(&http.url, send_event_notification::v1::Request::new(notifi))
						.await?;
					self.send_request(&http.url, send_event_notification::v1::Request::new(notifi)).await?;
				}

				Ok(())
			}
			},
			// TODO: Handle email
			//PusherKind::Email(_) => Ok(()),
			_ => Ok(()),
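The send_notice hunk above raises the notification priority for encrypted events or highlight/sound tweaks, and, when the pusher asks for the "event_id_only" format, strips the payload down to identifiers. A minimal sketch of the priority decision using hypothetical local types instead of the ruma ones (not part of this commit):

// Illustrative sketch of the priority rule shown in the hunk above.
#[derive(Debug, PartialEq)]
enum Priority { Low, High }

enum Tweak { Highlight(bool), Sound(String), Other }

fn notification_priority(event_is_encrypted: bool, tweaks: &[Tweak]) -> Priority {
    if event_is_encrypted || tweaks.iter().any(|t| matches!(t, Tweak::Highlight(true) | Tweak::Sound(_))) {
        Priority::High
    } else {
        Priority::Low
    }
}

fn main() {
    assert_eq!(notification_priority(false, &[Tweak::Sound("default".to_owned())]), Priority::High);
    assert_eq!(notification_priority(false, &[Tweak::Highlight(false), Tweak::Other]), Priority::Low);
}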
@@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};

use crate::Result;

pub trait Data: Send + Sync {
	/// Creates or updates the alias to the given room id.
	fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()>;

@@ -13,12 +14,9 @@ pub trait Data: Send + Sync {

	/// Returns all local aliases that point to the given room
	fn local_aliases_for_room<'a>(
		&'a self,
		room_id: &RoomId,
		&'a self, room_id: &RoomId,
	) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a>;

	/// Returns all local aliases on the server
	fn all_local_aliases<'a>(
		&'a self,
	) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
	fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a>;
}

@@ -1,9 +1,9 @@
mod data;

pub use data::Data;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};

use crate::Result;
use ruma::{OwnedRoomAliasId, OwnedRoomId, RoomAliasId, RoomId};

pub struct Service {
	pub db: &'static dyn Data,

@@ -11,14 +11,10 @@ pub struct Service {

impl Service {
	#[tracing::instrument(skip(self))]
	pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> {
		self.db.set_alias(alias, room_id)
	}
	pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId) -> Result<()> { self.db.set_alias(alias, room_id) }

	#[tracing::instrument(skip(self))]
	pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
		self.db.remove_alias(alias)
	}
	pub fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> { self.db.remove_alias(alias) }

	#[tracing::instrument(skip(self))]
	pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {

@@ -27,16 +23,13 @@ impl Service {

	#[tracing::instrument(skip(self))]
	pub fn local_aliases_for_room<'a>(
		&'a self,
		room_id: &RoomId,
		&'a self, room_id: &RoomId,
	) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a> {
		self.db.local_aliases_for_room(room_id)
	}

	#[tracing::instrument(skip(self))]
	pub fn all_local_aliases<'a>(
		&'a self,
	) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
	pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
		self.db.all_local_aliases()
	}
}
@@ -1,11 +1,8 @@
use crate::Result;
use std::{collections::HashSet, sync::Arc};

use crate::Result;

pub trait Data: Send + Sync {
	fn get_cached_eventid_authchain(
		&self,
		shorteventid: &[u64],
	) -> Result<Option<Arc<HashSet<u64>>>>;
	fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>)
		-> Result<()>;
	fn get_cached_eventid_authchain(&self, shorteventid: &[u64]) -> Result<Option<Arc<HashSet<u64>>>>;
	fn cache_auth_chain(&self, shorteventid: Vec<u64>, auth_chain: Arc<HashSet<u64>>) -> Result<()>;
}

@@ -26,9 +26,7 @@ impl Service {

	#[tracing::instrument(skip(self, starting_events))]
	pub async fn get_auth_chain<'a>(
		&self,
		room_id: &RoomId,
		starting_events: Vec<Arc<EventId>>,
		&self, room_id: &RoomId, starting_events: Vec<Arc<EventId>>,
	) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
		const NUM_BUCKETS: usize = 50;

@@ -55,11 +53,7 @@ impl Service {
			}

			let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
			if let Some(cached) = services()
				.rooms
				.auth_chain
				.get_cached_eventid_authchain(&chunk_key)?
			{
			if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&chunk_key)? {
				hits += 1;
				full_auth_chain.extend(cached.iter().copied());
				continue;

@@ -71,20 +65,13 @@ impl Service {
			let mut misses2 = 0;
			let mut i = 0;
			for (sevent_id, event_id) in chunk {
				if let Some(cached) = services()
					.rooms
					.auth_chain
					.get_cached_eventid_authchain(&[sevent_id])?
				{
				if let Some(cached) = services().rooms.auth_chain.get_cached_eventid_authchain(&[sevent_id])? {
					hits2 += 1;
					chunk_cache.extend(cached.iter().copied());
				} else {
					misses2 += 1;
					let auth_chain = Arc::new(self.get_auth_chain_inner(room_id, &event_id)?);
					services()
						.rooms
						.auth_chain
						.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
					services().rooms.auth_chain.cache_auth_chain(vec![sevent_id], Arc::clone(&auth_chain))?;
					debug!(
						event_id = ?event_id,
						chain_length = ?auth_chain.len(),

@@ -105,10 +92,7 @@ impl Service {
				"Chunk missed",
			);
			let chunk_cache = Arc::new(chunk_cache);
			services()
				.rooms
				.auth_chain
				.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
			services().rooms.auth_chain.cache_auth_chain(chunk_key, Arc::clone(&chunk_cache))?;
			full_auth_chain.extend(chunk_cache.iter());
		}

@@ -119,9 +103,7 @@ impl Service {
			"Auth chain stats",
		);

		Ok(full_auth_chain
			.into_iter()
			.filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
		Ok(full_auth_chain.into_iter().filter_map(move |sid| services().rooms.short.get_eventid_from_short(sid).ok()))
	}

	#[tracing::instrument(skip(self, event_id))]

@@ -136,23 +118,20 @@ impl Service {
					return Err(Error::BadRequest(ErrorKind::Forbidden, "Evil event in db"));
				}
				for auth_event in &pdu.auth_events {
					let sauthevent = services()
						.rooms
						.short
						.get_or_create_shorteventid(auth_event)?;
					let sauthevent = services().rooms.short.get_or_create_shorteventid(auth_event)?;

					if !found.contains(&sauthevent) {
						found.insert(sauthevent);
						todo.push(auth_event.clone());
					}
				}
			}
			},
			Ok(None) => {
				warn!(?event_id, "Could not find pdu mentioned in auth events");
			}
			},
			Err(error) => {
				error!(?event_id, ?error, "Could not load event in auth chain");
			}
			},
		}
	}
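The get_auth_chain hunks above split the starting events into chunks (NUM_BUCKETS = 50) and consult a two-level cache: first a whole-chunk entry, then per-event entries, computing and caching anything missing before caching the assembled chunk. A standalone sketch of that lookup pattern with plain HashMaps standing in for the database-backed cache (not part of this commit; names and the bucketing step are simplified):

// Illustrative sketch of the chunk-then-per-event cache lookup used above.
use std::collections::{HashMap, HashSet};

fn cached_auth_chain(
    chunk: &[u64],
    chunk_cache: &mut HashMap<Vec<u64>, HashSet<u64>>,
    event_cache: &mut HashMap<u64, HashSet<u64>>,
    compute: impl Fn(u64) -> HashSet<u64>,
) -> HashSet<u64> {
    // Fast path: the whole chunk was already resolved once.
    if let Some(hit) = chunk_cache.get(chunk) {
        return hit.clone();
    }
    // Otherwise resolve event by event, caching each result, then cache the chunk.
    let mut result = HashSet::new();
    for &event in chunk {
        let chain = event_cache.entry(event).or_insert_with(|| compute(event));
        result.extend(chain.iter().copied());
    }
    chunk_cache.insert(chunk.to_vec(), result.clone());
    result
}

fn main() {
    let mut chunk_cache = HashMap::new();
    let mut event_cache = HashMap::new();
    let compute = |event: u64| HashSet::from([event, event + 100]);
    let first = cached_auth_chain(&[1, 2], &mut chunk_cache, &mut event_cache, &compute);
    assert_eq!(first.len(), 4);
    // A second call for the same chunk is served from the chunk cache.
    assert!(chunk_cache.contains_key(&vec![1, 2]));
}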
@@ -1,6 +1,7 @@
use crate::Result;
use ruma::{OwnedRoomId, RoomId};

use crate::Result;

pub trait Data: Send + Sync {
	/// Adds the room to the public room directory
	fn set_public(&self, room_id: &RoomId) -> Result<()>;

@@ -11,22 +11,14 @@ pub struct Service {

impl Service {
	#[tracing::instrument(skip(self))]
	pub fn set_public(&self, room_id: &RoomId) -> Result<()> {
		self.db.set_public(room_id)
	}
	pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }

	#[tracing::instrument(skip(self))]
	pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
		self.db.set_not_public(room_id)
	}
	pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }

	#[tracing::instrument(skip(self))]
	pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
		self.db.is_public_room(room_id)
	}
	pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }

	#[tracing::instrument(skip(self))]
	pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
		self.db.public_rooms()
	}
	pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
}
@@ -1,33 +1,27 @@
use ruma::{events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId};

use crate::Result;
use ruma::{
	events::presence::PresenceEvent, presence::PresenceState, OwnedUserId, RoomId, UInt, UserId,
};

pub trait Data: Send + Sync {
	/// Returns the latest presence event for the given user in the given room.
	fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>>;

	/// Pings the presence of the given user in the given room, setting the specified state.
	/// Pings the presence of the given user in the given room, setting the
	/// specified state.
	fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()>;

	/// Adds a presence event which will be saved until a new event replaces it.
	fn set_presence(
		&self,
		room_id: &RoomId,
		user_id: &UserId,
		presence_state: PresenceState,
		currently_active: Option<bool>,
		last_active_ago: Option<UInt>,
		status_msg: Option<String>,
		&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
		last_active_ago: Option<UInt>, status_msg: Option<String>,
	) -> Result<()>;

	/// Removes the presence record for the given user from the database.
	fn remove_presence(&self, user_id: &UserId) -> Result<()>;

	/// Returns the most recent presence updates that happened after the event with id `since`.
	/// Returns the most recent presence updates that happened after the event
	/// with id `since`.
	fn presence_since<'a>(
		&'a self,
		room_id: &RoomId,
		since: u64,
		&'a self, room_id: &RoomId, since: u64,
	) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)> + 'a>;
}

@@ -15,7 +15,8 @@ use tracing::debug;

use crate::{services, utils, Error, Result};

/// Represents data required to be kept in order to implement the presence specification.
/// Represents data required to be kept in order to implement the presence
/// specification.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct Presence {
	pub state: PresenceState,

@@ -27,11 +28,7 @@ pub struct Presence {

impl Presence {
	pub fn new(
		state: PresenceState,
		currently_active: bool,
		last_active_ts: u64,
		last_count: u64,
		status_msg: Option<String>,
		state: PresenceState, currently_active: bool, last_active_ts: u64, last_count: u64, status_msg: Option<String>,
	) -> Self {
		Self {
			state,

@@ -43,13 +40,11 @@ impl Presence {
	}

	pub fn from_json_bytes(bytes: &[u8]) -> Result<Self> {
		serde_json::from_slice(bytes)
			.map_err(|_| Error::bad_database("Invalid presence data in database"))
		serde_json::from_slice(bytes).map_err(|_| Error::bad_database("Invalid presence data in database"))
	}

	pub fn to_json_bytes(&self) -> Result<Vec<u8>> {
		serde_json::to_vec(self)
			.map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
		serde_json::to_vec(self).map_err(|_| Error::bad_database("Could not serialize Presence to JSON"))
	}

	/// Creates a PresenceEvent from available data.

@@ -58,9 +53,7 @@ impl Presence {
		let last_active_ago = if self.currently_active {
			None
		} else {
			Some(UInt::new_saturating(
				now.saturating_sub(self.last_active_ts),
			))
			Some(UInt::new_saturating(now.saturating_sub(self.last_active_ts)))
		};

		Ok(PresenceEvent {

@@ -83,49 +76,31 @@ pub struct Service {

impl Service {
	/// Returns the latest presence event for the given user in the given room.
	pub fn get_presence(
		&self,
		room_id: &RoomId,
		user_id: &UserId,
	) -> Result<Option<PresenceEvent>> {
	pub fn get_presence(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<PresenceEvent>> {
		self.db.get_presence(room_id, user_id)
	}

	/// Pings the presence of the given user in the given room, setting the specified state.
	/// Pings the presence of the given user in the given room, setting the
	/// specified state.
	pub fn ping_presence(&self, user_id: &UserId, new_state: PresenceState) -> Result<()> {
		self.db.ping_presence(user_id, new_state)
	}

	/// Adds a presence event which will be saved until a new event replaces it.
	pub fn set_presence(
		&self,
		room_id: &RoomId,
		user_id: &UserId,
		presence_state: PresenceState,
		currently_active: Option<bool>,
		last_active_ago: Option<UInt>,
		status_msg: Option<String>,
		&self, room_id: &RoomId, user_id: &UserId, presence_state: PresenceState, currently_active: Option<bool>,
		last_active_ago: Option<UInt>, status_msg: Option<String>,
	) -> Result<()> {
		self.db.set_presence(
			room_id,
			user_id,
			presence_state,
			currently_active,
			last_active_ago,
			status_msg,
		)
		self.db.set_presence(room_id, user_id, presence_state, currently_active, last_active_ago, status_msg)
	}

	/// Removes the presence record for the given user from the database.
	pub fn remove_presence(&self, user_id: &UserId) -> Result<()> {
		self.db.remove_presence(user_id)
	}
	pub fn remove_presence(&self, user_id: &UserId) -> Result<()> { self.db.remove_presence(user_id) }

	/// Returns the most recent presence updates that happened after the event with id `since`.
	/// Returns the most recent presence updates that happened after the event
	/// with id `since`.
	pub fn presence_since(
		&self,
		room_id: &RoomId,
		since: u64,
		&self, room_id: &RoomId, since: u64,
	) -> Box<dyn Iterator<Item = (OwnedUserId, u64, PresenceEvent)>> {
		self.db.presence_since(room_id, since)
	}

@@ -167,11 +142,7 @@ fn process_presence_timer(user_id: OwnedUserId) -> Result<()> {
	let mut status_msg = None;

	for room_id in services().rooms.state_cache.rooms_joined(&user_id) {
		let presence_event = services()
			.rooms
			.edus
			.presence
			.get_presence(&room_id?, &user_id)?;
		let presence_event = services().rooms.edus.presence.get_presence(&room_id?, &user_id)?;

		if let Some(presence_event) = presence_event {
			presence_state = presence_event.content.presence;

@@ -183,12 +154,8 @@ fn process_presence_timer(user_id: OwnedUserId) -> Result<()> {
	}

	let new_state = match (&presence_state, last_active_ago.map(u64::from)) {
		(PresenceState::Online, Some(ago)) if ago >= idle_timeout => {
			Some(PresenceState::Unavailable)
		}
		(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => {
			Some(PresenceState::Offline)
		}
		(PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable),
		(PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline),
		_ => None,
	};
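The process_presence_timer hunk above demotes stale presence in two steps: Online becomes Unavailable after the idle timeout, Unavailable becomes Offline after the offline timeout, and everything else is left alone. A pure-function sketch of that transition (not part of this commit; the timeout values in the example are placeholders, the real ones come from the server configuration):

// Illustrative sketch of the presence timer transition shown above.
#[derive(Clone, Copy, Debug, PartialEq)]
enum PresenceState { Online, Unavailable, Offline }

fn next_state(current: PresenceState, last_active_ago: Option<u64>, idle_timeout: u64, offline_timeout: u64) -> Option<PresenceState> {
    match (current, last_active_ago) {
        (PresenceState::Online, Some(ago)) if ago >= idle_timeout => Some(PresenceState::Unavailable),
        (PresenceState::Unavailable, Some(ago)) if ago >= offline_timeout => Some(PresenceState::Offline),
        // Everything else keeps its current state.
        _ => None,
    }
}

fn main() {
    // With a 5 minute idle timeout, an online user idle for 6 minutes becomes unavailable.
    assert_eq!(next_state(PresenceState::Online, Some(360_000), 300_000, 1_800_000), Some(PresenceState::Unavailable));
    assert_eq!(next_state(PresenceState::Offline, Some(360_000), 300_000, 1_800_000), None);
}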
@@ -1,25 +1,21 @@
use crate::Result;
use ruma::{
	events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
	serde::Raw,
	OwnedUserId, RoomId, UserId,
};

use crate::Result;

type AnySyncEphemeralRoomEventIter<'a> =
	Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;

pub trait Data: Send + Sync {
	/// Replaces the previous read receipt.
	fn readreceipt_update(
		&self,
		user_id: &UserId,
		room_id: &RoomId,
		event: ReceiptEvent,
	) -> Result<()>;
	fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: ReceiptEvent) -> Result<()>;

	/// Returns an iterator over the most recent read_receipts in a room that happened after the event with id `since`.
	fn readreceipts_since(&self, room_id: &RoomId, since: u64)
		-> AnySyncEphemeralRoomEventIter<'_>;
	/// Returns an iterator over the most recent read_receipts in a room that
	/// happened after the event with id `since`.
	fn readreceipts_since(&self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'_>;

	/// Sets a private read marker at `count`.
	fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()>;