From 76c5942b4f7ebb3b11b31f2df5a3e405166cd945 Mon Sep 17 00:00:00 2001 From: strawberry Date: Sun, 28 Apr 2024 13:18:09 -0400 Subject: [PATCH] use user_is_local and server_is_ours more, remove few double filters Signed-off-by: strawberry --- src/api/client_server/directory.rs | 4 ++-- src/api/client_server/keys.rs | 2 +- src/api/client_server/media.rs | 10 ++++++---- src/api/client_server/profile.rs | 8 ++++---- src/api/client_server/state.rs | 6 ++++-- src/api/client_server/to_device.rs | 4 ++-- src/api/server_server.rs | 20 ++++++-------------- src/service/rooms/state_cache/mod.rs | 4 ++-- src/service/rooms/timeline/mod.rs | 27 ++++++++++++++------------- src/service/sending/mod.rs | 8 ++++---- src/service/sending/sender.rs | 4 ++-- 11 files changed, 47 insertions(+), 50 deletions(-) diff --git a/src/api/client_server/directory.rs b/src/api/client_server/directory.rs index 563f7e08..c8f0a23a 100644 --- a/src/api/client_server/directory.rs +++ b/src/api/client_server/directory.rs @@ -24,7 +24,7 @@ use ruma::{ }; use tracing::{error, info, warn}; -use crate::{services, Error, Result, Ruma}; +use crate::{services, utils::server_name::server_is_ours, Error, Result, Ruma}; /// # `POST /_matrix/client/v3/publicRooms` /// @@ -173,7 +173,7 @@ pub(crate) async fn get_room_visibility_route( pub(crate) async fn get_public_rooms_filtered_helper( server: Option<&ServerName>, limit: Option, since: Option<&str>, filter: &Filter, _network: &RoomNetwork, ) -> Result { - if let Some(other_server) = server.filter(|server| *server != services().globals.server_name().as_str()) { + if let Some(other_server) = server.filter(|server_name| !server_is_ours(server_name)) { let response = services() .sending .send_federation_request( diff --git a/src/api/client_server/keys.rs b/src/api/client_server/keys.rs index d0206dfc..7909767d 100644 --- a/src/api/client_server/keys.rs +++ b/src/api/client_server/keys.rs @@ -458,7 +458,7 @@ pub(crate) async fn claim_keys_helper( let mut get_over_federation = BTreeMap::new(); for (user_id, map) in one_time_keys_input { - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(user_id) { get_over_federation .entry(user_id.server_name()) .or_insert_with(Vec::new) diff --git a/src/api/client_server/media.rs b/src/api/client_server/media.rs index bed79bc1..9f4c45b1 100644 --- a/src/api/client_server/media.rs +++ b/src/api/client_server/media.rs @@ -16,7 +16,9 @@ use webpage::HTML; use crate::{ debug_warn, service::media::{FileMeta, UrlPreviewData}, - services, utils, Error, Result, Ruma, RumaResponse, + services, + utils::{self, server_name::server_is_ours}, + Error, Result, Ruma, RumaResponse, }; /// generated MXC ID (`media-id`) length @@ -181,7 +183,7 @@ pub(crate) async fn get_content_route(body: Ruma) -> R cross_origin_resource_policy: Some("cross-origin".to_owned()), cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if !server_is_ours(&body.server_name) && body.allow_remote { get_remote_content( &mxc, &body.server_name, @@ -243,7 +245,7 @@ pub(crate) async fn get_content_as_filename_route( cross_origin_resource_policy: Some("cross-origin".to_owned()), cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if !server_is_ours(&body.server_name) && body.allow_remote { match get_remote_content( &mxc, &body.server_name, @@ -324,7 +326,7 @@ pub(crate) async fn get_content_thumbnail_route( cross_origin_resource_policy: Some("cross-origin".to_owned()), cache_control: Some(CACHE_CONTROL_IMMUTABLE.into()), }) - } else if &*body.server_name != services().globals.server_name() && body.allow_remote { + } else if !server_is_ours(&body.server_name) && body.allow_remote { if services() .globals .prevent_media_downloads_from() diff --git a/src/api/client_server/profile.rs b/src/api/client_server/profile.rs index 68a12db0..a8cf9af2 100644 --- a/src/api/client_server/profile.rs +++ b/src/api/client_server/profile.rs @@ -13,7 +13,7 @@ use ruma::{ }; use serde_json::value::to_raw_value; -use crate::{service::pdu::PduBuilder, services, Error, Result, Ruma}; +use crate::{service::pdu::PduBuilder, services, utils::user_id::user_is_local, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/profile/{userId}/displayname` /// @@ -105,7 +105,7 @@ pub(crate) async fn set_displayname_route( pub(crate) async fn get_displayname_route( body: Ruma, ) -> Result { - if body.user_id.server_name() != services().globals.server_name() { + if !user_is_local(&body.user_id) { // Create and update our local copy of the user if let Ok(response) = services() .sending @@ -247,7 +247,7 @@ pub(crate) async fn set_avatar_url_route( pub(crate) async fn get_avatar_url_route( body: Ruma, ) -> Result { - if body.user_id.server_name() != services().globals.server_name() { + if !user_is_local(&body.user_id) { // Create and update our local copy of the user if let Ok(response) = services() .sending @@ -303,7 +303,7 @@ pub(crate) async fn get_avatar_url_route( /// - If user is on another server and we do not have a local copy already, /// fetch profile over federation. pub(crate) async fn get_profile_route(body: Ruma) -> Result { - if body.user_id.server_name() != services().globals.server_name() { + if !user_is_local(&body.user_id) { // Create and update our local copy of the user if let Ok(response) = services() .sending diff --git a/src/api/client_server/state.rs b/src/api/client_server/state.rs index b6a9f0ae..aee8b21e 100644 --- a/src/api/client_server/state.rs +++ b/src/api/client_server/state.rs @@ -20,7 +20,9 @@ use tracing::{error, log::warn}; use crate::{ service::{self, pdu::PduBuilder}, - services, Error, Result, Ruma, RumaResponse, + services, + utils::server_name::server_is_ours, + Error, Result, Ruma, RumaResponse, }; /// # `PUT /_matrix/client/*/rooms/{roomId}/state/{eventType}/{stateKey}` @@ -279,7 +281,7 @@ async fn send_state_event_for_key_helper( } for alias in aliases { - if alias.server_name() != services().globals.server_name() + if !server_is_ours(alias.server_name()) || services() .rooms .alias diff --git a/src/api/client_server/to_device.rs b/src/api/client_server/to_device.rs index 15521aa8..e85b991f 100644 --- a/src/api/client_server/to_device.rs +++ b/src/api/client_server/to_device.rs @@ -8,7 +8,7 @@ use ruma::{ to_device::DeviceIdOrAllDevices, }; -use crate::{services, Error, Result, Ruma}; +use crate::{services, utils::user_id::user_is_local, Error, Result, Ruma}; /// # `PUT /_matrix/client/r0/sendToDevice/{eventType}/{txnId}` /// @@ -30,7 +30,7 @@ pub(crate) async fn send_event_to_device_route( for (target_user_id, map) in &body.messages { for (target_device_id_maybe, event) in map { - if target_user_id.server_name() != services().globals.server_name() { + if !user_is_local(target_user_id) { let mut map = BTreeMap::new(); map.insert(target_device_id_maybe.clone(), event.clone()); let mut messages = BTreeMap::new(); diff --git a/src/api/server_server.rs b/src/api/server_server.rs index 5651cbde..66fccc7e 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -56,7 +56,7 @@ use crate::{ debug_error, service::pdu::{gen_event_id_canonical_json, PduBuilder}, services, - utils::{self, user_id::user_is_local}, + utils::{self, server_name::server_is_ours, user_id::user_is_local}, Error, PduEvent, Result, Ruma, }; @@ -1456,7 +1456,7 @@ async fn create_leave_event(sender_servername: &ServerName, room_id: &RoomId, pd .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); + .filter(|server| !server_is_ours(server)); services().sending.send_pdu_servers(servers, &pdu_id)?; @@ -1651,7 +1651,7 @@ pub(crate) async fn create_invite_route(body: Ruma) /// /// Gets information on all devices of the user. pub(crate) async fn get_devices_route(body: Ruma) -> Result { - if body.user_id.server_name() != services().globals.server_name() { + if !user_is_local(&body.user_id) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Tried to access user from other server.", @@ -1757,7 +1757,7 @@ pub(crate) async fn get_profile_information_route( )); } - if body.user_id.server_name() != services().globals.server_name() { + if !server_is_ours(body.user_id.server_name()) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this server", @@ -1796,11 +1796,7 @@ pub(crate) async fn get_profile_information_route( /// /// Gets devices and identity keys for the given users. pub(crate) async fn get_keys_route(body: Ruma) -> Result { - if body - .device_keys - .iter() - .any(|(u, _)| u.server_name() != services().globals.server_name()) - { + if body.device_keys.iter().any(|(u, _)| !user_is_local(u)) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "User does not belong to this server.", @@ -1826,11 +1822,7 @@ pub(crate) async fn get_keys_route(body: Ruma) -> Result< /// /// Claims one-time keys. pub(crate) async fn claim_keys_route(body: Ruma) -> Result { - if body - .one_time_keys - .iter() - .any(|(u, _)| u.server_name() != services().globals.server_name()) - { + if body.one_time_keys.iter().any(|(u, _)| !user_is_local(u)) { return Err(Error::BadRequest( ErrorKind::InvalidParam, "Tried to access user from other server.", diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index fa7914e8..9f2ccdf4 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -19,7 +19,7 @@ use ruma::{ }; use tracing::{error, warn}; -use crate::{service::appservice::RegistrationInfo, services, Error, Result}; +use crate::{service::appservice::RegistrationInfo, services, utils::user_id::user_is_local, Error, Result}; mod data; @@ -43,7 +43,7 @@ impl Service { // TODO: use futures to update remote profiles without blocking the membership // update #[allow(clippy::collapsible_if)] - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(user_id) { if !services().users.exists(user_id)? { services().users.create(user_id, None)?; } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index e9fa67ed..ea9f1613 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -41,7 +41,9 @@ use crate::{ appservice::NamespaceRegex, pdu::{EventHash, PduBuilder}, }, - services, utils, Error, PduEvent, Result, + services, + utils::{self, server_name::server_is_ours}, + Error, PduEvent, Result, }; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] @@ -852,8 +854,7 @@ impl Service { .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) + .filter(|m| server_is_ours(m.server_name()) && m != target) .count(); if count < 2 { warn!("Last admin cannot leave from admins room"); @@ -878,8 +879,7 @@ impl Service { .state_cache .room_members(room_id) .filter_map(Result::ok) - .filter(|m| m.server_name() == server_name) - .filter(|m| m != target) + .filter(|m| server_is_ours(m.server_name()) && m != target) .count(); if count < 2 { warn!("Last admin cannot be banned in admins room"); @@ -1059,8 +1059,9 @@ impl Service { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| services().globals.trusted_servers().contains(server)) - .filter(|server| server != services().globals.server_name()), + .filter(|server_name| { + services().globals.trusted_servers().contains(server_name) && !server_is_ours(server_name) + }), ); // add server names from room aliases on the room ID @@ -1071,16 +1072,16 @@ impl Service { .collect::, _>>(); if let Ok(aliases) = &room_aliases { for alias in aliases { - if alias.server_name() != services().globals.server_name() { + if !server_is_ours(alias.server_name()) { servers.push(alias.server_name().to_owned()); } } } // add room ID server name for backfill server - if let Some(server) = room_id.server_name() { - if server != services().globals.server_name() { - servers.push(server.to_owned()); + if let Some(server_name) = room_id.server_name() { + if !server_is_ours(server_name) { + servers.push(server_name.to_owned()); } } @@ -1102,7 +1103,7 @@ impl Service { .iter() .filter(|(_, level)| **level > power_levels.users_default) .map(|(user_id, _)| user_id.server_name()) - .filter(|server| server != &services().globals.server_name()) + .filter(|server_name| !server_is_ours(server_name)) .map(ToOwned::to_owned), ); @@ -1110,7 +1111,7 @@ impl Service { if let Some(server_index) = servers .clone() .into_iter() - .position(|server| server == services().globals.server_name()) + .position(|server_name| server_is_ours(&server_name)) { servers.remove(server_index); } diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs index 36e8f780..ed01e797 100644 --- a/src/service/sending/mod.rs +++ b/src/service/sending/mod.rs @@ -8,7 +8,7 @@ use ruma::{ use tokio::sync::Mutex; use tracing::warn; -use crate::{services, Config, Error, Result}; +use crate::{services, utils::server_name::server_is_ours, Config, Error, Result}; mod appservice; mod data; @@ -93,7 +93,7 @@ impl Service { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); + .filter(|server_name| !server_is_ours(server_name)); self.send_pdu_servers(servers, pdu_id) } @@ -144,7 +144,7 @@ impl Service { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); + .filter(|server_name| !server_is_ours(server_name)); self.send_edu_servers(servers, serialized) } @@ -183,7 +183,7 @@ impl Service { .state_cache .room_servers(room_id) .filter_map(Result::ok) - .filter(|server| &**server != services().globals.server_name()); + .filter(|server_name| !server_is_ours(server_name)); self.flush_servers(servers) } diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs index 5ec7ebad..e7d00234 100644 --- a/src/service/sending/sender.rs +++ b/src/service/sending/sender.rs @@ -293,7 +293,7 @@ fn select_edus_presence( for (user_id, count, presence_bytes) in services().presence.presence_since(since) { *max_edu_count = cmp::max(count, *max_edu_count); - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(&user_id) { continue; } @@ -341,7 +341,7 @@ fn select_edus_receipts( let (user_id, count, read_receipt) = r?; *max_edu_count = cmp::max(count, *max_edu_count); - if user_id.server_name() != services().globals.server_name() { + if !user_is_local(&user_id) { continue; }