diff --git a/src/api/client/space.rs b/src/api/client/space.rs index 409c9083..8f54de2a 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -1,9 +1,15 @@ -use std::str::FromStr; +use std::{collections::VecDeque, str::FromStr}; use axum::extract::State; +use conduwuit::{checked, pdu::ShortRoomId, utils::stream::IterStream}; +use futures::{StreamExt, TryFutureExt}; use ruma::{ api::client::{error::ErrorKind, space::get_hierarchy}, - UInt, + OwnedRoomId, OwnedServerName, RoomId, UInt, UserId, +}; +use service::{ + rooms::spaces::{get_parent_children_via, summary_to_chunk, SummaryAccessibility}, + Services, }; use crate::{service::rooms::spaces::PaginationToken, Error, Result, Ruma}; @@ -16,8 +22,6 @@ pub(crate) async fn get_hierarchy_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - let limit = body .limit .unwrap_or_else(|| UInt::from(10_u32)) @@ -43,16 +47,163 @@ pub(crate) async fn get_hierarchy_route( } } - services - .rooms - .spaces - .get_client_hierarchy( - sender_user, - &body.room_id, - limit.try_into().unwrap_or(10), - key.map_or(vec![], |token| token.short_room_ids), - max_depth.into(), - body.suggested_only, - ) - .await + get_client_hierarchy( + &services, + body.sender_user(), + &body.room_id, + limit.try_into().unwrap_or(10), + key.map_or(vec![], |token| token.short_room_ids), + max_depth.into(), + body.suggested_only, + ) + .await +} + +async fn get_client_hierarchy( + services: &Services, + sender_user: &UserId, + room_id: &RoomId, + limit: usize, + short_room_ids: Vec, + max_depth: u64, + suggested_only: bool, +) -> Result { + let mut parents = VecDeque::new(); + + // Don't start populating the results if we have to start at a specific room. + let mut populate_results = short_room_ids.is_empty(); + + let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() { + | Some(server_name) => vec![server_name.into()], + | None => vec![], + })]]; + + let mut results = Vec::with_capacity(limit); + + while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { + if results.len() >= limit { + break; + } + + match ( + services + .rooms + .spaces + .get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) + .await?, + current_room == room_id, + ) { + | (Some(SummaryAccessibility::Accessible(summary)), _) => { + let mut children: Vec<(OwnedRoomId, Vec)> = + get_parent_children_via(&summary, suggested_only) + .into_iter() + .filter(|(room, _)| parents.iter().all(|parent| parent != room)) + .rev() + .collect(); + + if populate_results { + results.push(summary_to_chunk(*summary.clone())); + } else { + children = children + .iter() + .rev() + .stream() + .skip_while(|(room, _)| { + services + .rooms + .short + .get_shortroomid(room) + .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .unwrap_or_else(|_| false) + }) + .map(Clone::clone) + .collect::)>>() + .await + .into_iter() + .rev() + .collect(); + + if children.is_empty() { + return Err(Error::BadRequest( + ErrorKind::InvalidParam, + "Room IDs in token were not found.", + )); + } + + // We have reached the room after where we last left off + let parents_len = parents.len(); + if checked!(parents_len + 1)? == short_room_ids.len() { + populate_results = true; + } + } + + let parents_len: u64 = parents.len().try_into()?; + if !children.is_empty() && parents_len < max_depth { + parents.push_back(current_room.clone()); + stack.push(children); + } + // Root room in the space hierarchy, we return an error + // if this one fails. + }, + | (Some(SummaryAccessibility::Inaccessible), true) => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "The requested room is inaccessible", + )); + }, + | (None, true) => { + return Err(Error::BadRequest( + ErrorKind::forbidden(), + "The requested room was not found", + )); + }, + // Just ignore other unavailable rooms + | (None | Some(SummaryAccessibility::Inaccessible), false) => (), + } + } + + Ok(get_hierarchy::v1::Response { + next_batch: if let Some((room, _)) = next_room_to_traverse(&mut stack, &mut parents) { + parents.pop_front(); + parents.push_back(room); + + let next_short_room_ids: Vec<_> = parents + .iter() + .stream() + .filter_map(|room_id| async move { + services.rooms.short.get_shortroomid(room_id).await.ok() + }) + .collect() + .await; + + (next_short_room_ids != short_room_ids && !next_short_room_ids.is_empty()).then( + || { + PaginationToken { + short_room_ids: next_short_room_ids, + limit: UInt::new(max_depth) + .expect("When sent in request it must have been valid UInt"), + max_depth: UInt::new(max_depth) + .expect("When sent in request it must have been valid UInt"), + suggested_only, + } + .to_string() + }, + ) + } else { + None + }, + rooms: results, + }) +} + +fn next_room_to_traverse( + stack: &mut Vec)>>, + parents: &mut VecDeque, +) -> Option<(OwnedRoomId, Vec)> { + while stack.last().is_some_and(Vec::is_empty) { + stack.pop(); + parents.pop_back(); + } + + stack.last_mut().and_then(Vec::pop) } diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index a10df6ac..bcf2f7bc 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -1,7 +1,12 @@ use axum::extract::State; -use ruma::api::{client::error::ErrorKind, federation::space::get_hierarchy}; +use conduwuit::{Err, Result}; +use ruma::{api::federation::space::get_hierarchy, RoomId, ServerName}; +use service::{ + rooms::spaces::{get_parent_children_via, Identifier, SummaryAccessibility}, + Services, +}; -use crate::{Error, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/federation/v1/hierarchy/{roomId}` /// @@ -11,13 +16,58 @@ pub(crate) async fn get_hierarchy_route( State(services): State, body: Ruma, ) -> Result { - if services.rooms.metadata.exists(&body.room_id).await { - services - .rooms - .spaces - .get_federation_hierarchy(&body.room_id, body.origin(), body.suggested_only) - .await - } else { - Err(Error::BadRequest(ErrorKind::NotFound, "Room does not exist.")) + if !services.rooms.metadata.exists(&body.room_id).await { + return Err!(Request(NotFound("Room does not exist."))); + } + + get_hierarchy(&services, &body.room_id, body.origin(), body.suggested_only).await +} + +/// Gets the response for the space hierarchy over federation request +/// +/// Errors if the room does not exist, so a check if the room exists should +/// be done +async fn get_hierarchy( + services: &Services, + room_id: &RoomId, + server_name: &ServerName, + suggested_only: bool, +) -> Result { + match services + .rooms + .spaces + .get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name)) + .await? + { + | Some(SummaryAccessibility::Accessible(room)) => { + let mut children = Vec::new(); + let mut inaccessible_children = Vec::new(); + + for (child, _via) in get_parent_children_via(&room, suggested_only) { + match services + .rooms + .spaces + .get_summary_and_children_local(&child, Identifier::ServerName(server_name)) + .await? + { + | Some(SummaryAccessibility::Accessible(summary)) => { + children.push((*summary).into()); + }, + | Some(SummaryAccessibility::Inaccessible) => { + inaccessible_children.push(child); + }, + | None => (), + } + } + + Ok(get_hierarchy::v1::Response { + room: *room, + children, + inaccessible_children, + }) + }, + | Some(SummaryAccessibility::Inaccessible) => + Err!(Request(NotFound("The requested room is inaccessible"))), + | None => Err!(Request(NotFound("The requested room was not found"))), } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 11794752..1e2b0a9f 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,22 +1,14 @@ +mod pagination_token; mod tests; -use std::{ - collections::{HashMap, VecDeque}, - fmt::{Display, Formatter}, - str::FromStr, - sync::Arc, -}; +use std::{collections::HashMap, sync::Arc}; -use conduwuit::{ - checked, debug_info, err, - utils::{math::usize_from_f64, IterStream}, - Error, Result, -}; -use futures::{StreamExt, TryFutureExt}; +use conduwuit::{debug_info, err, utils::math::usize_from_f64, Error, Result}; +use futures::StreamExt; use lru_cache::LruCache; use ruma::{ api::{ - client::{self, error::ErrorKind, space::SpaceHierarchyRoomsChunk}, + client::{error::ErrorKind, space::SpaceHierarchyRoomsChunk}, federation::{ self, space::{SpaceHierarchyChildSummary, SpaceHierarchyParentSummary}, @@ -29,11 +21,12 @@ use ruma::{ }, serde::Raw, space::SpaceRoomJoinRule, - OwnedRoomId, OwnedServerName, RoomId, ServerName, UInt, UserId, + OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; use tokio::sync::Mutex; -use crate::{rooms, rooms::short::ShortRoomId, sending, Dep}; +pub use self::pagination_token::PaginationToken; +use crate::{rooms, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -44,81 +37,10 @@ pub enum SummaryAccessibility { Inaccessible, } -// TODO: perhaps use some better form of token rather than just room count -#[derive(Debug, Eq, PartialEq)] -pub struct PaginationToken { - /// Path down the hierarchy of the room to start the response at, - /// excluding the root space. - pub short_room_ids: Vec, - pub limit: UInt, - pub max_depth: UInt, - pub suggested_only: bool, -} - -impl FromStr for PaginationToken { - type Err = Error; - - fn from_str(value: &str) -> Result { - let mut values = value.split('_'); - - let mut pag_tok = || { - let rooms = values - .next()? - .split(',') - .filter_map(|room_s| u64::from_str(room_s).ok()) - .collect(); - - Some(Self { - short_room_ids: rooms, - limit: UInt::from_str(values.next()?).ok()?, - max_depth: UInt::from_str(values.next()?).ok()?, - suggested_only: { - let slice = values.next()?; - - if values.next().is_none() { - if slice == "true" { - true - } else if slice == "false" { - false - } else { - None? - } - } else { - None? - } - }, - }) - }; - - if let Some(token) = pag_tok() { - Ok(token) - } else { - Err(Error::BadRequest(ErrorKind::InvalidParam, "invalid token")) - } - } -} - -impl Display for PaginationToken { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!( - f, - "{}_{}_{}_{}", - self.short_room_ids - .iter() - .map(ToString::to_string) - .collect::>() - .join(","), - self.limit, - self.max_depth, - self.suggested_only - ) - } -} - /// Identifier used to check if rooms are accessible /// /// None is used if you want to return the room, no matter if accessible or not -enum Identifier<'a> { +pub enum Identifier<'a> { UserId(&'a UserId), ServerName(&'a ServerName), } @@ -164,60 +86,8 @@ impl crate::Service for Service { } impl Service { - /// Gets the response for the space hierarchy over federation request - /// - /// Errors if the room does not exist, so a check if the room exists should - /// be done - pub async fn get_federation_hierarchy( - &self, - room_id: &RoomId, - server_name: &ServerName, - suggested_only: bool, - ) -> Result { - match self - .get_summary_and_children_local( - &room_id.to_owned(), - Identifier::ServerName(server_name), - ) - .await? - { - | Some(SummaryAccessibility::Accessible(room)) => { - let mut children = Vec::new(); - let mut inaccessible_children = Vec::new(); - - for (child, _via) in get_parent_children_via(&room, suggested_only) { - match self - .get_summary_and_children_local( - &child, - Identifier::ServerName(server_name), - ) - .await? - { - | Some(SummaryAccessibility::Accessible(summary)) => { - children.push((*summary).into()); - }, - | Some(SummaryAccessibility::Inaccessible) => { - inaccessible_children.push(child); - }, - | None => (), - } - } - - Ok(federation::space::get_hierarchy::v1::Response { - room: *room, - children, - inaccessible_children, - }) - }, - | Some(SummaryAccessibility::Inaccessible) => - Err(Error::BadRequest(ErrorKind::NotFound, "The requested room is inaccessible")), - | None => - Err(Error::BadRequest(ErrorKind::NotFound, "The requested room was not found")), - } - } - /// Gets the summary of a space using solely local information - async fn get_summary_and_children_local( + pub async fn get_summary_and_children_local( &self, current_room: &OwnedRoomId, identifier: Identifier<'_>, @@ -366,7 +236,7 @@ impl Service { /// Gets the summary of a space using either local or remote (federation) /// sources - async fn get_summary_and_children_client( + pub async fn get_summary_and_children_client( &self, current_room: &OwnedRoomId, suggested_only: bool, @@ -470,147 +340,6 @@ impl Service { }) } - pub async fn get_client_hierarchy( - &self, - sender_user: &UserId, - room_id: &RoomId, - limit: usize, - short_room_ids: Vec, - max_depth: u64, - suggested_only: bool, - ) -> Result { - let mut parents = VecDeque::new(); - - // Don't start populating the results if we have to start at a specific room. - let mut populate_results = short_room_ids.is_empty(); - - let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() { - | Some(server_name) => vec![server_name.into()], - | None => vec![], - })]]; - - let mut results = Vec::with_capacity(limit); - - while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } - { - if results.len() >= limit { - break; - } - - match ( - self.get_summary_and_children_client( - ¤t_room, - suggested_only, - sender_user, - &via, - ) - .await?, - current_room == room_id, - ) { - | (Some(SummaryAccessibility::Accessible(summary)), _) => { - let mut children: Vec<(OwnedRoomId, Vec)> = - get_parent_children_via(&summary, suggested_only) - .into_iter() - .filter(|(room, _)| parents.iter().all(|parent| parent != room)) - .rev() - .collect(); - - if populate_results { - results.push(summary_to_chunk(*summary.clone())); - } else { - children = children - .iter() - .rev() - .stream() - .skip_while(|(room, _)| { - self.services - .short - .get_shortroomid(room) - .map_ok(|short| { - Some(&short) != short_room_ids.get(parents.len()) - }) - .unwrap_or_else(|_| false) - }) - .map(Clone::clone) - .collect::)>>() - .await - .into_iter() - .rev() - .collect(); - - if children.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room IDs in token were not found.", - )); - } - - // We have reached the room after where we last left off - let parents_len = parents.len(); - if checked!(parents_len + 1)? == short_room_ids.len() { - populate_results = true; - } - } - - let parents_len: u64 = parents.len().try_into()?; - if !children.is_empty() && parents_len < max_depth { - parents.push_back(current_room.clone()); - stack.push(children); - } - // Root room in the space hierarchy, we return an error - // if this one fails. - }, - | (Some(SummaryAccessibility::Inaccessible), true) => { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "The requested room is inaccessible", - )); - }, - | (None, true) => { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "The requested room was not found", - )); - }, - // Just ignore other unavailable rooms - | (None | Some(SummaryAccessibility::Inaccessible), false) => (), - } - } - - Ok(client::space::get_hierarchy::v1::Response { - next_batch: if let Some((room, _)) = next_room_to_traverse(&mut stack, &mut parents) { - parents.pop_front(); - parents.push_back(room); - - let next_short_room_ids: Vec<_> = parents - .iter() - .stream() - .filter_map(|room_id| async move { - self.services.short.get_shortroomid(room_id).await.ok() - }) - .collect() - .await; - - (next_short_room_ids != short_room_ids && !next_short_room_ids.is_empty()).then( - || { - PaginationToken { - short_room_ids: next_short_room_ids, - limit: UInt::new(max_depth) - .expect("When sent in request it must have been valid UInt"), - max_depth: UInt::new(max_depth) - .expect("When sent in request it must have been valid UInt"), - suggested_only, - } - .to_string() - }, - ) - } else { - None - }, - rooms: results, - }) - } - /// Simply returns the stripped m.space.child events of a room async fn get_stripped_space_child_events( &self, @@ -757,7 +486,8 @@ impl From for SpaceHierarchyRoomsChunk { /// Here because cannot implement `From` across ruma-federation-api and /// ruma-client-api types -fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRoomsChunk { +#[must_use] +pub fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRoomsChunk { let SpaceHierarchyParentSummary { canonical_alias, name, @@ -790,7 +520,8 @@ fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyRooms /// Returns the children of a SpaceHierarchyParentSummary, making use of the /// children_state field -fn get_parent_children_via( +#[must_use] +pub fn get_parent_children_via( parent: &SpaceHierarchyParentSummary, suggested_only: bool, ) -> Vec<(OwnedRoomId, Vec)> { @@ -808,15 +539,3 @@ fn get_parent_children_via( }) .collect() } - -fn next_room_to_traverse( - stack: &mut Vec)>>, - parents: &mut VecDeque, -) -> Option<(OwnedRoomId, Vec)> { - while stack.last().is_some_and(Vec::is_empty) { - stack.pop(); - parents.pop_back(); - } - - stack.last_mut().and_then(Vec::pop) -} diff --git a/src/service/rooms/spaces/pagination_token.rs b/src/service/rooms/spaces/pagination_token.rs new file mode 100644 index 00000000..8f019e8d --- /dev/null +++ b/src/service/rooms/spaces/pagination_token.rs @@ -0,0 +1,76 @@ +use std::{ + fmt::{Display, Formatter}, + str::FromStr, +}; + +use conduwuit::{Error, Result}; +use ruma::{api::client::error::ErrorKind, UInt}; + +use crate::rooms::short::ShortRoomId; + +// TODO: perhaps use some better form of token rather than just room count +#[derive(Debug, Eq, PartialEq)] +pub struct PaginationToken { + /// Path down the hierarchy of the room to start the response at, + /// excluding the root space. + pub short_room_ids: Vec, + pub limit: UInt, + pub max_depth: UInt, + pub suggested_only: bool, +} + +impl FromStr for PaginationToken { + type Err = Error; + + fn from_str(value: &str) -> Result { + let mut values = value.split('_'); + let mut pag_tok = || { + let short_room_ids = values + .next()? + .split(',') + .filter_map(|room_s| u64::from_str(room_s).ok()) + .collect(); + + let limit = UInt::from_str(values.next()?).ok()?; + let max_depth = UInt::from_str(values.next()?).ok()?; + let slice = values.next()?; + let suggested_only = if values.next().is_none() { + if slice == "true" { + true + } else if slice == "false" { + false + } else { + None? + } + } else { + None? + }; + + Some(Self { + short_room_ids, + limit, + max_depth, + suggested_only, + }) + }; + + if let Some(token) = pag_tok() { + Ok(token) + } else { + Err(Error::BadRequest(ErrorKind::InvalidParam, "invalid token")) + } + } +} + +impl Display for PaginationToken { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + let short_room_ids = self + .short_room_ids + .iter() + .map(ToString::to_string) + .collect::>() + .join(","); + + write!(f, "{short_room_ids}_{}_{}_{}", self.limit, self.max_depth, self.suggested_only) + } +}