diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index f59da7b7..2d04be2f 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -152,21 +152,23 @@ pub(crate) async fn join_room_by_id_route( let mut servers = services() .rooms .state_cache - .servers_invite_via(&body.room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &body.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::>(), - ); + .servers_invite_via(&body.room_id) + .filter_map(Result::ok) + .collect::>(); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &body.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), + ); if let Some(server) = body.room_id.server_name() { servers.push(server.into()); @@ -206,21 +208,22 @@ pub(crate) async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .servers_invite_via(&room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(), - ), + .servers_invite_via(&room_id) + .filter_map(Result::ok), + ); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), ); if let Some(server) = room_id.server_name() { @@ -240,21 +243,22 @@ pub(crate) async fn join_room_by_id_or_alias_route( services() .rooms .state_cache - .servers_invite_via(&response.room_id)? - .unwrap_or( - services() - .rooms - .state_cache - .invite_state(sender_user, &response.room_id)? - .unwrap_or_default() - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect(), - ), + .servers_invite_via(&response.room_id) + .filter_map(Result::ok), + ); + + servers.extend( + services() + .rooms + .state_cache + .invite_state(sender_user, &response.room_id)? + .unwrap_or_default() + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()), ); (servers, response.room_id) @@ -1680,21 +1684,23 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> { .invite_state(user_id, room_id)? .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; - let servers: HashSet = services() + let mut servers: HashSet = services() .rooms .state_cache - .servers_invite_via(room_id)? - .map_or( - invite_state - .iter() - .filter_map(|event| serde_json::from_str(event.json().get()).ok()) - .filter_map(|event: serde_json::Value| event.get("sender").cloned()) - .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) - .filter_map(|sender| UserId::parse(sender).ok()) - .map(|user| user.server_name().to_owned()) - .collect::>(), - HashSet::from_iter, - ); + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect(); + + servers.extend( + invite_state + .iter() + .filter_map(|event| serde_json::from_str(event.json().get()).ok()) + .filter_map(|event: serde_json::Value| event.get("sender").cloned()) + .filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) + .filter_map(|sender| UserId::parse(sender).ok()) + .map(|user| user.server_name().to_owned()) + .collect::>(), + ); debug!("servers in remote_leave_room: {servers:?}"); diff --git a/src/service/rooms/state_cache/data.rs b/src/service/rooms/state_cache/data.rs index acf390d9..d380c1ab 100644 --- a/src/service/rooms/state_cache/data.rs +++ b/src/service/rooms/state_cache/data.rs @@ -6,7 +6,6 @@ use ruma::{ serde::Raw, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, }; -use tracing::error; use crate::{ appservice::RegistrationInfo, @@ -94,7 +93,7 @@ pub trait Data: Send + Sync { /// Gets the servers to either accept or decline invites via for a given /// room. - fn servers_invite_via(&self, room_id: &RoomId) -> Result>>; + fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box> + 'a>; /// Add the given servers the list to accept or decline invites via for a /// given room. @@ -159,7 +158,10 @@ impl Data for KeyValueDatabase { self.roomuserid_leftcount.remove(&roomuser_id)?; if let Some(servers) = invite_via { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); + let mut prev_servers = self + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect_vec(); #[allow(clippy::redundant_clone)] // this is a necessary clone? prev_servers.append(servers.clone().as_mut()); let servers = prev_servers.iter().rev().unique().rev().collect_vec(); @@ -639,30 +641,40 @@ impl Data for KeyValueDatabase { } #[tracing::instrument(skip(self))] - fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box> + 'a> { let key = room_id.as_bytes().to_vec(); - self.roomid_inviteviaservers - .get(&key)? - .map(|servers| { - let state = serde_json::from_slice(&servers).map_err(|e| { - error!("Invalid state in userroomid_leftstate: {e}"); - Error::bad_database("Invalid state in userroomid_leftstate.") - })?; - - Ok(state) - }) - .transpose() + Box::new( + self.roomid_inviteviaservers + .scan_prefix(key) + .map(|(_, servers)| { + ServerName::parse( + utils::string_from_bytes( + servers + .rsplit(|&b| b == 0xFF) + .next() + .expect("rsplit always returns an element"), + ) + .map_err(|_| { + Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.") + })?, + ) + .map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid.")) + }), + ) } #[tracing::instrument(skip(self))] fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { - let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); - prev_servers.append(servers.to_owned().as_mut()); + let mut prev_servers = self + .servers_invite_via(room_id) + .filter_map(Result::ok) + .collect_vec(); + prev_servers.extend(servers.to_owned()); + prev_servers.sort_unstable(); + prev_servers.dedup(); - let servers = prev_servers.iter().rev().unique().rev().collect_vec(); - - let servers = servers + let servers = prev_servers .iter() .map(|server| server.as_bytes()) .collect_vec() diff --git a/src/service/rooms/state_cache/mod.rs b/src/service/rooms/state_cache/mod.rs index 2b3f5fb7..c9d1bbd7 100644 --- a/src/service/rooms/state_cache/mod.rs +++ b/src/service/rooms/state_cache/mod.rs @@ -377,7 +377,7 @@ impl Service { pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result { self.db.is_left(user_id, room_id) } #[tracing::instrument(skip(self))] - pub fn servers_invite_via(&self, room_id: &RoomId) -> Result>> { + pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator> + '_ { self.db.servers_invite_via(room_id) }