use single global function for server name local and user local checking

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-04-28 11:35:25 -04:00 committed by June
parent 8f17d965b2
commit 9931e60050
14 changed files with 77 additions and 41 deletions

View file

@ -18,7 +18,9 @@ use tracing::{error, info, warn};
use super::{DEVICE_ID_LENGTH, SESSION_ID_LENGTH, TOKEN_LENGTH};
use crate::{
api::client_server::{self, join_room_by_id_helper},
service, services, utils, Error, Result, Ruma,
service, services,
utils::{self, user_id::user_is_local},
Error, Result, Ruma,
};
const RANDOM_USER_ID_LENGTH: usize = 10;
@ -40,7 +42,7 @@ pub(crate) async fn get_register_available_route(
// Validate user id
let user_id = UserId::parse_with_server_name(body.username.to_lowercase(), services().globals.server_name())
.ok()
.filter(|user_id| !user_id.is_historical() && user_id.server_name() == services().globals.server_name())
.filter(|user_id| !user_id.is_historical() && user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
// Check if username is creative enough
@ -125,9 +127,7 @@ pub(crate) async fn register_route(body: Ruma<register::v3::Request>) -> Result<
let proposed_user_id =
UserId::parse_with_server_name(username.to_lowercase(), services().globals.server_name())
.ok()
.filter(|user_id| {
!user_id.is_historical() && user_id.server_name() == services().globals.server_name()
})
.filter(|user_id| !user_id.is_historical() && user_is_local(user_id))
.ok_or(Error::BadRequest(ErrorKind::InvalidUsername, "Username is invalid."))?;
if services().users.exists(&proposed_user_id)? {

View file

@ -12,13 +12,13 @@ use ruma::{
};
use tracing::debug;
use crate::{debug_info, debug_warn, services, Error, Result, Ruma};
use crate::{debug_info, debug_warn, services, utils::server_name::server_is_ours, Error, Result, Ruma};
/// # `PUT /_matrix/client/v3/directory/room/{roomAlias}`
///
/// Creates a new room alias on this server.
pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) -> Result<create_alias::v3::Response> {
if body.room_alias.server_name() != services().globals.server_name() {
if !server_is_ours(body.room_alias.server_name()) {
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
}
@ -73,7 +73,7 @@ pub(crate) async fn create_alias_route(body: Ruma<create_alias::v3::Request>) ->
/// - TODO: additional access control checks
/// - TODO: Update canonical alias event
pub(crate) async fn delete_alias_route(body: Ruma<delete_alias::v3::Request>) -> Result<delete_alias::v3::Response> {
if body.room_alias.server_name() != services().globals.server_name() {
if !server_is_ours(body.room_alias.server_name()) {
return Err(Error::BadRequest(ErrorKind::InvalidParam, "Alias is from another server."));
}
@ -126,7 +126,7 @@ pub(crate) async fn get_alias_helper(
room_alias: OwnedRoomAliasId, servers: Option<Vec<OwnedServerName>>,
) -> Result<get_alias::v3::Response> {
debug!("get_alias_helper servers: {servers:?}");
if room_alias.server_name() != services().globals.server_name()
if !server_is_ours(room_alias.server_name())
&& (!servers
.as_ref()
.is_some_and(|servers| servers.contains(&services().globals.server_name().to_owned()))
@ -204,7 +204,7 @@ pub(crate) async fn get_alias_helper(
// room alias server first
if let Some(server_index) = servers
.iter()
.position(|server| server == services().globals.server_name())
.position(|server_name| server_is_ours(server_name))
{
servers.remove(server_index);
servers.insert(0, services().globals.server_name().to_owned());
@ -277,7 +277,7 @@ pub(crate) async fn get_alias_helper(
// insert our server as the very first choice if in list
if let Some(server_index) = servers
.iter()
.position(|server| server == services().globals.server_name())
.position(|server_name| server_is_ours(server_name))
{
servers.remove(server_index);
servers.insert(0, services().globals.server_name().to_owned());

View file

@ -20,7 +20,11 @@ use serde_json::json;
use tracing::debug;
use super::SESSION_ID_LENGTH;
use crate::{services, utils, Error, Result, Ruma};
use crate::{
services,
utils::{self, user_id::user_is_local},
Error, Result, Ruma,
};
/// # `POST /_matrix/client/r0/keys/upload`
///
@ -260,7 +264,7 @@ pub(crate) async fn get_keys_helper<F: Fn(&UserId) -> bool>(
for (user_id, device_ids) in device_keys_input {
let user_id: &UserId = user_id;
if user_id.server_name() != services().globals.server_name() {
if !user_is_local(user_id) {
get_over_federation
.entry(user_id.server_name())
.or_insert_with(Vec::new)

View file

@ -34,7 +34,9 @@ use tracing::{debug, error, info, trace, warn};
use super::get_alias_helper;
use crate::{
service::pdu::{gen_event_id_canonical_json, PduBuilder},
services, utils, Error, PduEvent, Result, Ruma,
services,
utils::{self, server_name::server_is_ours, user_id::user_is_local},
Error, PduEvent, Result, Ruma,
};
/// # `POST /_matrix/client/r0/rooms/{roomId}/join`
@ -1088,7 +1090,7 @@ pub(crate) async fn join_room_by_id_helper(
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.filter(|user| user.server_name() == services().globals.server_name())
.filter(|user| user_is_local(user))
.collect::<Vec<OwnedUserId>>();
let mut authorized_user: Option<OwnedUserId> = None;
@ -1150,7 +1152,7 @@ pub(crate) async fn join_room_by_id_helper(
if !restriction_rooms.is_empty()
&& servers
.iter()
.any(|s| *s != services().globals.server_name())
.any(|server_name| !server_is_ours(server_name))
{
info!(
"We couldn't do the join locally, maybe federation can help to satisfy the restricted join \
@ -1303,7 +1305,7 @@ async fn make_join_request(
let mut incompatible_room_version_count = 0;
for remote_server in servers {
if remote_server == services().globals.server_name() {
if server_is_ours(remote_server) {
continue;
}
info!("Asking {remote_server} for make_join ({make_join_counter})");
@ -1436,7 +1438,7 @@ pub(crate) async fn invite_helper(
));
}
if user_id.server_name() != services().globals.server_name() {
if !user_is_local(user_id) {
let (pdu, pdu_json, invite_room_state) = {
let mutex_state = Arc::clone(
services()

View file

@ -55,7 +55,9 @@ use crate::{
api::client_server::{self, claim_keys_helper, get_keys_helper},
debug_error,
service::pdu::{gen_event_id_canonical_json, PduBuilder},
services, utils, Error, PduEvent, Result, Ruma,
services,
utils::{self, user_id::user_is_local},
Error, PduEvent, Result, Ruma,
};
/// # `GET /_matrix/federation/v1/version`
@ -978,7 +980,7 @@ pub(crate) async fn create_join_event_template_route(
.state_cache
.room_members(&body.room_id)
.filter_map(Result::ok)
.filter(|user| user.server_name() == services().globals.server_name())
.filter(|user| user_is_local(user))
.collect();
let mut auth_user = None;