diff --git a/src/api/client_server/membership.rs b/src/api/client_server/membership.rs index 655dc0a4..244ea489 100644 --- a/src/api/client_server/membership.rs +++ b/src/api/client_server/membership.rs @@ -59,15 +59,21 @@ pub async fn join_room_by_id_route(body: Ruma) -> let mut servers = services() .rooms .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::>(); + .servers_invite_via(&body.room_id)? + .unwrap_or( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::>(), + ); if let Some(server) = body.room_id.server_name() { servers.push(server.into()); @@ -112,14 +118,21 @@ pub async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()), + .servers_invite_via(&room_id)? + .unwrap_or( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect(), + ), ); if let Some(server) = room_id.server_name() { @@ -1328,6 +1341,7 @@ pub(crate) async fn invite_helper( room_version: room_version_id.clone(), event: PduEvent::convert_to_outgoing_federation_event(pdu_json.clone()), invite_room_state, + via: services().rooms.state_cache.servers_route_via(room_id).ok(), }, ) .await?; @@ -1483,18 +1497,15 @@ pub async fn leave_room(user_id: &UserId, room_id: &RoomId, reason: Option { error!("Trying to leave a room you are not a member of."); - services() - .rooms - .state_cache - .update_membership( - room_id, - user_id, - RoomMemberEventContent::new(MembershipState::Leave), - user_id, - None, - true, - ) - .await?; + services().rooms.state_cache.update_membership( + room_id, + user_id, + RoomMemberEventContent::new(MembershipState::Leave), + user_id, + None, + None, + true, + )?; return Ok(()); }, Some(e) => e, @@ -1573,14 +1581,21 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { .invite_state(user_id, room_id)? .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; - let servers: HashSet<_> = invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(); + let servers: HashSet = services() + .rooms + .state_cache + .servers_invite_via(room_id)? + .map_or( + invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::>(), + HashSet::from_iter, + ); for remote_server in servers { let make_leave_response = services() diff --git a/src/api/server_server.rs b/src/api/server_server.rs index d5b60dbd..4faa4747 100644 --- a/src/api/server_server.rs +++ b/src/api/server_server.rs @@ -1253,6 +1253,12 @@ pub async fn create_invite_route(body: Ruma) -> Resu )); } + if let Some(via) = &body.via { + if via.is_empty() { + return Err(Error::BadRequest(ErrorKind::InvalidParam, "via field must not be empty.")); + } + } + let mut signed_event = utils::to_canonical_object(&body.event).map_err(|e| { error!("Failed to convert invite event to canonical JSON: {}", e); Error::BadRequest(ErrorKind::InvalidParam, "Invite event is invalid.") @@ -1339,18 +1345,15 @@ pub async fn create_invite_route(body: Ruma) -> Resu .state_cache .server_in_room(services().globals.server_name(), &body.room_id)? { - services() - .rooms - .state_cache - .update_membership( - &body.room_id, - &invited_user, - RoomMemberEventContent::new(MembershipState::Invite), - &sender, - Some(invite_state), - true, - ) - .await?; + services().rooms.state_cache.update_membership( + &body.room_id, + &invited_user, + RoomMemberEventContent::new(MembershipState::Invite), + &sender, + Some(invite_state), + body.via.clone(), + true, + )?; } Ok(create_invite::v2::Response { diff --git a/src/database/key_value/rooms/state_cache.rs b/src/database/key_value/rooms/state_cache.rs index d37a6b9a..badf2875 100644 --- a/src/database/key_value/rooms/state_cache.rs +++ b/src/database/key_value/rooms/state_cache.rs @@ -1,5 +1,6 @@ use std::{collections::HashSet, sync::Arc}; +use itertools::Itertools; use ruma::{ events::{AnyStrippedStateEvent, AnySyncStateEvent}, serde::Raw, @@ -25,7 +26,11 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { } fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xFF); + + let mut roomuser_id = roomid_prefix.clone(); roomuser_id.push(0xFF); roomuser_id.extend_from_slice(user_id.as_bytes()); @@ -40,11 +45,24 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; + if self + .roomuserid_joined + .scan_prefix(roomid_prefix.clone()) + .count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + Ok(()) } fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, ) -> Result<()> { let mut roomuser_id = room_id.as_bytes().to_vec(); roomuser_id.push(0xFF); @@ -65,12 +83,31 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_leftstate.remove(&userroom_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?; + if let Some(servers) = invite_via { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + #[allow(clippy::redundant_clone)] // this is a necessary clone? + prev_servers.append(servers.clone().as_mut()); + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers)?; + } + Ok(()) } fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> { - let mut roomuser_id = room_id.as_bytes().to_vec(); - roomuser_id.push(0xFF); + let roomid = room_id.as_bytes().to_vec(); + let mut roomid_prefix = room_id.as_bytes().to_vec(); + roomid_prefix.push(0xFF); + + let mut roomuser_id = roomid_prefix.clone(); roomuser_id.extend_from_slice(user_id.as_bytes()); let mut userroom_id = user_id.as_bytes().to_vec(); @@ -88,6 +125,18 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { self.userroomid_invitestate.remove(&userroom_id)?; self.roomuserid_invitecount.remove(&roomuser_id)?; + if self + .roomuserid_joined + .scan_prefix(roomid_prefix.clone()) + .count() == 0 + && self + .roomuserid_invitecount + .scan_prefix(roomid_prefix) + .count() == 0 + { + self.roomid_inviteviaservers.remove(&roomid)?; + } + Ok(()) } @@ -537,4 +586,38 @@ impl service::rooms::state_cache::Data for KeyValueDatabase { Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some()) } + + #[tracing::instrument(skip(self))] + fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + let room_id = room_id.as_bytes().to_vec(); + + self.roomid_inviteviaservers + .get(&room_id)? + .map(|servers| { + let state = serde_json::from_slice(&servers) + .map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?; + + Ok(state) + }) + .transpose() + } + + #[tracing::instrument(skip(self))] + fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { + let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + prev_servers.append(servers.to_owned().as_mut()); + + let servers = prev_servers.iter().rev().unique().rev().collect_vec(); + + let servers = servers + .iter() + .map(|server| server.as_bytes()) + .collect_vec() + .join(&[0xFF][..]); + + self.roomid_inviteviaservers + .insert(room_id.as_bytes(), &servers)?; + + Ok(()) + } } diff --git a/src/database/mod.rs b/src/database/mod.rs index f4eef27a..6c656426 100644 --- a/src/database/mod.rs +++ b/src/database/mod.rs @@ -51,6 +51,8 @@ pub struct KeyValueDatabase { pub(super) global: Arc, pub(super) server_signingkeys: Arc, + pub(super) roomid_inviteviaservers: Arc, + //pub users: users::Users, pub(super) userid_password: Arc, pub(super) userid_displayname: Arc, @@ -342,6 +344,8 @@ impl KeyValueDatabase { global: builder.open_tree("global")?, server_signingkeys: builder.open_tree("server_signingkeys")?, + roomid_inviteviaservers: builder.open_tree("roomid_inviteviaservers")?, + auth_chain_cache: Mutex::new(LruCache::new( (f64::from(config.auth_chain_cache_capacity) * config.conduit_cache_capacity_modifier) as usize, )), diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 32713070..10ffda85 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -68,11 +68,15 @@ impl Service { continue; }; - services() - .rooms - .state_cache - .update_membership(room_id, &user_id, membership_event, &pdu.sender, None, false) - .await?; + services().rooms.state_cache.update_membership( + room_id, + &user_id, + membership_event, + &pdu.sender, + None, + None, + false, + )?; }, TimelineEventType::SpaceChild => { services() diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index 6e97396f..4cd9c8fa 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -17,6 +17,7 @@ pub trait Data: Send + Sync { fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; fn mark_as_invited( &self, user_id: &UserId, room_id: &RoomId, last_state: Option>>, + invite_via: Option>, ) -> Result<()>; fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()>; @@ -75,4 +76,12 @@ pub trait Data: Send + Sync { fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result; fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result; + + /// Gets the servers to either accept or decline invites via for a given + /// room. + fn servers_invite_via(&self, room_id: &RoomId) -> Result>>; + + /// Add the given servers the list to accept or decline invites via for a + /// given room. + fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()>; } diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 81f33f71..4c23824a 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -1,6 +1,7 @@ use std::{collections::HashSet, sync::Arc}; pub use data::Data; +use itertools::Itertools; use ruma::{ events::{ direct::DirectEvent, @@ -8,13 +9,15 @@ use ruma::{ room::{ create::RoomCreateEventContent, member::{MembershipState, RoomMemberEventContent}, + power_levels::RoomPowerLevelsEventContent, }, AnyStrippedStateEvent, AnySyncStateEvent, GlobalAccountDataEventType, RoomAccountDataEventType, StateEventType, }, + int, serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use tracing::warn; +use tracing::{error, warn}; use crate::{service::appservice::RegistrationInfo, services, Error, Result}; @@ -27,9 +30,11 @@ pub struct Service { impl Service { /// Update current membership data. #[tracing::instrument(skip(self, last_state))] - pub async fn update_membership( + #[allow(clippy::too_many_arguments)] + pub fn update_membership( &self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId, - last_state: Option>>, update_joined_count: bool, + last_state: Option>>, invite_via: Option>, + update_joined_count: bool, ) -> Result<()> { let membership = membership_event.membership; @@ -188,7 +193,8 @@ impl Service { return Ok(()); } - self.db.mark_as_invited(user_id, room_id, last_state)?; + self.db + .mark_as_invited(user_id, room_id, last_state, invite_via)?; }, MembershipState::Leave | MembershipState::Ban => { self.db.mark_as_left(user_id, room_id)?; @@ -344,4 +350,63 @@ impl Service { #[tracing::instrument(skip(self))] pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } + + #[tracing::instrument(skip(self))] + pub fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + self.db.servers_invite_via(room_id) + } + + /// Gets up to three servers that are likely to be in the room in the + /// distant future. + /// + /// See https://spec.matrix.org/v1.10/appendices/#routing + #[tracing::instrument(skip(self))] + pub fn servers_route_via(&self, room_id: &RoomId) -> Result> { + let most_powerful_user_server = services() + .rooms + .state_accessor + .room_state_get(room_id, &StateEventType::RoomPowerLevels, "")? + .map(|pdu| { + serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| { + conent + .users + .iter() + .max_by_key(|(_, power)| *power) + .and_then(|x| { + if x.1 >= &int!(50) { + Some(x) + } else { + None + } + }) + .map(|(user, _power)| user.server_name().to_owned()) + }) + }) + .transpose() + .map_err(|e| { + error!("Invalid power levels event content in database: {e}"); + Error::bad_database("Invalid power levels event content in database") + })? + .flatten(); + + let mut servers: Vec = services() + .rooms + .state_cache + .room_members(room_id) + .filter_map(Result::ok) + .counts_by(|user| user.server_name().to_owned()) + .iter() + .sorted_by_key(|(_, users)| *users) + .map(|(server, _)| server.to_owned()) + .rev() + .take(3) + .collect_vec(); + + if let Some(server) = most_powerful_user_server { + servers.insert(0, server); + servers.truncate(3); + } + + Ok(servers) + } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index d136a4f1..89b1b009 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -481,11 +481,15 @@ impl Service { // Update our membership info, we do this here incase a user is invited // and immediately leaves we need the DB to record the invite event for auth - services() - .rooms - .state_cache - .update_membership(&pdu.room_id, &target_user_id, content, &pdu.sender, invite_state, true) - .await?; + services().rooms.state_cache.update_membership( + &pdu.room_id, + &target_user_id, + content, + &pdu.sender, + invite_state, + None, + true, + )?; } }, TimelineEventType::RoomMessage => {