fix broken federated room invites/joins

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-06-10 14:53:26 -04:00
parent f0557e3303
commit cb03654dc1
3 changed files with 97 additions and 79 deletions

View file

@ -152,8 +152,11 @@ pub(crate) async fn join_room_by_id_route(
let mut servers = services() let mut servers = services()
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&body.room_id)? .servers_invite_via(&body.room_id)
.unwrap_or( .filter_map(Result::ok)
.collect::<Vec<_>>();
servers.extend(
services() services()
.rooms .rooms
.state_cache .state_cache
@ -164,8 +167,7 @@ pub(crate) async fn join_room_by_id_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()) .map(|user| user.server_name().to_owned()),
.collect::<Vec<_>>(),
); );
if let Some(server) = body.room_id.server_name() { if let Some(server) = body.room_id.server_name() {
@ -206,8 +208,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
services() services()
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&room_id)? .servers_invite_via(&room_id)
.unwrap_or( .filter_map(Result::ok),
);
servers.extend(
services() services()
.rooms .rooms
.state_cache .state_cache
@ -218,9 +223,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()) .map(|user| user.server_name().to_owned()),
.collect(),
),
); );
if let Some(server) = room_id.server_name() { if let Some(server) = room_id.server_name() {
@ -240,8 +243,11 @@ pub(crate) async fn join_room_by_id_or_alias_route(
services() services()
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(&response.room_id)? .servers_invite_via(&response.room_id)
.unwrap_or( .filter_map(Result::ok),
);
servers.extend(
services() services()
.rooms .rooms
.state_cache .state_cache
@ -252,9 +258,7 @@ pub(crate) async fn join_room_by_id_or_alias_route(
.filter_map(|event: serde_json::Value| event.get("sender").cloned()) .filter_map(|event: serde_json::Value| event.get("sender").cloned())
.filter_map(|sender| sender.as_str().map(ToOwned::to_owned)) .filter_map(|sender| sender.as_str().map(ToOwned::to_owned))
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()) .map(|user| user.server_name().to_owned()),
.collect(),
),
); );
(servers, response.room_id) (servers, response.room_id)
@ -1680,11 +1684,14 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
.invite_state(user_id, room_id)? .invite_state(user_id, room_id)?
.ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?; .ok_or(Error::BadRequest(ErrorKind::BadState, "User is not invited."))?;
let servers: HashSet<OwnedServerName> = services() let mut servers: HashSet<OwnedServerName> = services()
.rooms .rooms
.state_cache .state_cache
.servers_invite_via(room_id)? .servers_invite_via(room_id)
.map_or( .filter_map(Result::ok)
.collect();
servers.extend(
invite_state invite_state
.iter() .iter()
.filter_map(|event| serde_json::from_str(event.json().get()).ok()) .filter_map(|event| serde_json::from_str(event.json().get()).ok())
@ -1693,7 +1700,6 @@ async fn remote_leave_room(user_id: &UserId, room_id: &RoomId) -> Result<()> {
.filter_map(|sender| UserId::parse(sender).ok()) .filter_map(|sender| UserId::parse(sender).ok())
.map(|user| user.server_name().to_owned()) .map(|user| user.server_name().to_owned())
.collect::<HashSet<OwnedServerName>>(), .collect::<HashSet<OwnedServerName>>(),
HashSet::from_iter,
); );
debug!("servers in remote_leave_room: {servers:?}"); debug!("servers in remote_leave_room: {servers:?}");

View file

@ -6,7 +6,6 @@ use ruma::{
serde::Raw, serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
}; };
use tracing::error;
use crate::{ use crate::{
appservice::RegistrationInfo, appservice::RegistrationInfo,
@ -94,7 +93,7 @@ pub trait Data: Send + Sync {
/// Gets the servers to either accept or decline invites via for a given /// Gets the servers to either accept or decline invites via for a given
/// room. /// room.
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>; fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>;
/// Add the given servers the list to accept or decline invites via for a /// Add the given servers the list to accept or decline invites via for a
/// given room. /// given room.
@ -159,7 +158,10 @@ impl Data for KeyValueDatabase {
self.roomuserid_leftcount.remove(&roomuser_id)?; self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via { if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
#[allow(clippy::redundant_clone)] // this is a necessary clone? #[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut()); prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec(); let servers = prev_servers.iter().rev().unique().rev().collect_vec();
@ -639,30 +641,40 @@ impl Data for KeyValueDatabase {
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> { fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let key = room_id.as_bytes().to_vec(); let key = room_id.as_bytes().to_vec();
Box::new(
self.roomid_inviteviaservers self.roomid_inviteviaservers
.get(&key)? .scan_prefix(key)
.map(|servers| { .map(|(_, servers)| {
let state = serde_json::from_slice(&servers).map_err(|e| { ServerName::parse(
error!("Invalid state in userroomid_leftstate: {e}"); utils::string_from_bytes(
Error::bad_database("Invalid state in userroomid_leftstate.") servers
})?; .rsplit(|&b| b == 0xFF)
.next()
Ok(state) .expect("rsplit always returns an element"),
}) )
.transpose() .map_err(|_| {
Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
}),
)
} }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> { fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new()); let mut prev_servers = self
prev_servers.append(servers.to_owned().as_mut()); .servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers.iter().rev().unique().rev().collect_vec(); let servers = prev_servers
let servers = servers
.iter() .iter()
.map(|server| server.as_bytes()) .map(|server| server.as_bytes())
.collect_vec() .collect_vec()

View file

@ -377,7 +377,7 @@ impl Service {
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) } pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
#[tracing::instrument(skip(self))] #[tracing::instrument(skip(self))]
pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> { pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.servers_invite_via(room_id) self.db.servers_invite_via(room_id)
} }