fix broken federated room invites/joins

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-06-10 14:53:26 -04:00
parent f0557e3303
commit cb03654dc1
3 changed files with 97 additions and 79 deletions

View file

@ -6,7 +6,6 @@ use ruma::{
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
};
use tracing::error;
use crate::{
appservice::RegistrationInfo,
@ -94,7 +93,7 @@ pub trait Data: Send + Sync {
/// Gets the servers to either accept or decline invites via for a given
/// room.
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>>;
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a>;
/// Add the given servers the list to accept or decline invites via for a
/// given room.
@ -159,7 +158,10 @@ impl Data for KeyValueDatabase {
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
@ -639,30 +641,40 @@ impl Data for KeyValueDatabase {
}
#[tracing::instrument(skip(self))]
fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let key = room_id.as_bytes().to_vec();
self.roomid_inviteviaservers
.get(&key)?
.map(|servers| {
let state = serde_json::from_slice(&servers).map_err(|e| {
error!("Invalid state in userroomid_leftstate: {e}");
Error::bad_database("Invalid state in userroomid_leftstate.")
})?;
Ok(state)
})
.transpose()
Box::new(
self.roomid_inviteviaservers
.scan_prefix(key)
.map(|(_, servers)| {
ServerName::parse(
utils::string_from_bytes(
servers
.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
}),
)
}
#[tracing::instrument(skip(self))]
fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self.servers_invite_via(room_id)?.unwrap_or(Vec::new());
prev_servers.append(servers.to_owned().as_mut());
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()

View file

@ -377,7 +377,7 @@ impl Service {
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
#[tracing::instrument(skip(self))]
pub fn servers_invite_via(&self, room_id: &RoomId) -> Result<Option<Vec<OwnedServerName>>> {
pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.servers_invite_via(room_id)
}