diff --git a/src/api/client/space.rs b/src/api/client/space.rs index 8f54de2a..7efd7817 100644 --- a/src/api/client/space.rs +++ b/src/api/client/space.rs @@ -1,18 +1,25 @@ -use std::{collections::VecDeque, str::FromStr}; +use std::{ + collections::{BTreeSet, VecDeque}, + str::FromStr, +}; use axum::extract::State; -use conduwuit::{checked, pdu::ShortRoomId, utils::stream::IterStream}; -use futures::{StreamExt, TryFutureExt}; +use conduwuit::{ + utils::{future::TryExtExt, stream::IterStream}, + Err, Result, +}; +use futures::{future::OptionFuture, StreamExt, TryFutureExt}; use ruma::{ - api::client::{error::ErrorKind, space::get_hierarchy}, - OwnedRoomId, OwnedServerName, RoomId, UInt, UserId, + api::client::space::get_hierarchy, OwnedRoomId, OwnedServerName, RoomId, UInt, UserId, }; use service::{ - rooms::spaces::{get_parent_children_via, summary_to_chunk, SummaryAccessibility}, + rooms::spaces::{ + get_parent_children_via, summary_to_chunk, PaginationToken, SummaryAccessibility, + }, Services, }; -use crate::{service::rooms::spaces::PaginationToken, Error, Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/client/v1/rooms/{room_id}/hierarchy` /// @@ -40,10 +47,9 @@ pub(crate) async fn get_hierarchy_route( // Should prevent unexpeded behaviour in (bad) clients if let Some(ref token) = key { if token.suggested_only != body.suggested_only || token.max_depth != max_depth { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "suggested_only and max_depth cannot change on paginated requests", - )); + return Err!(Request(InvalidParam( + "suggested_only and max_depth cannot change on paginated requests" + ))); } } @@ -52,58 +58,70 @@ pub(crate) async fn get_hierarchy_route( body.sender_user(), &body.room_id, limit.try_into().unwrap_or(10), - key.map_or(vec![], |token| token.short_room_ids), - max_depth.into(), + max_depth.try_into().unwrap_or(usize::MAX), body.suggested_only, + key.as_ref() + .into_iter() + .flat_map(|t| t.short_room_ids.iter()), ) .await } -async fn get_client_hierarchy( +async fn get_client_hierarchy<'a, ShortRoomIds>( services: &Services, sender_user: &UserId, room_id: &RoomId, limit: usize, - short_room_ids: Vec, - max_depth: u64, + max_depth: usize, suggested_only: bool, -) -> Result { - let mut parents = VecDeque::new(); + short_room_ids: ShortRoomIds, +) -> Result +where + ShortRoomIds: Iterator + Clone + Send + Sync + 'a, +{ + type Via = Vec; + type Entry = (OwnedRoomId, Via); + type Rooms = VecDeque; - // Don't start populating the results if we have to start at a specific room. - let mut populate_results = short_room_ids.is_empty(); + let mut queue: Rooms = [( + room_id.to_owned(), + room_id + .server_name() + .map(ToOwned::to_owned) + .into_iter() + .collect(), + )] + .into(); - let mut stack = vec![vec![(room_id.to_owned(), match room_id.server_name() { - | Some(server_name) => vec![server_name.into()], - | None => vec![], - })]]; + let mut rooms = Vec::with_capacity(limit); + let mut parents = BTreeSet::new(); + while let Some((current_room, via)) = queue.pop_front() { + let summary = services + .rooms + .spaces + .get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) + .await?; - let mut results = Vec::with_capacity(limit); - - while let Some((current_room, via)) = { next_room_to_traverse(&mut stack, &mut parents) } { - if results.len() >= limit { - break; - } - - match ( - services - .rooms - .spaces - .get_summary_and_children_client(¤t_room, suggested_only, sender_user, &via) - .await?, - current_room == room_id, - ) { + match (summary, current_room == room_id) { + | (None | Some(SummaryAccessibility::Inaccessible), false) => { + // Just ignore other unavailable rooms + }, + | (None, true) => { + return Err!(Request(Forbidden("The requested room was not found"))); + }, + | (Some(SummaryAccessibility::Inaccessible), true) => { + return Err!(Request(Forbidden("The requested room is inaccessible"))); + }, | (Some(SummaryAccessibility::Accessible(summary)), _) => { - let mut children: Vec<(OwnedRoomId, Vec)> = - get_parent_children_via(&summary, suggested_only) - .into_iter() - .filter(|(room, _)| parents.iter().all(|parent| parent != room)) - .rev() - .collect(); + let populate = parents.len() >= short_room_ids.clone().count(); - if populate_results { - results.push(summary_to_chunk(*summary.clone())); - } else { + let mut children: Vec = get_parent_children_via(&summary, suggested_only) + .filter(|(room, _)| !parents.contains(room)) + .rev() + .map(|(key, val)| (key, val.collect())) + .collect(); + + if !populate { children = children .iter() .rev() @@ -113,97 +131,69 @@ async fn get_client_hierarchy( .rooms .short .get_shortroomid(room) - .map_ok(|short| Some(&short) != short_room_ids.get(parents.len())) + .map_ok(|short| { + Some(&short) != short_room_ids.clone().nth(parents.len()) + }) .unwrap_or_else(|_| false) }) .map(Clone::clone) - .collect::)>>() + .collect::>() .await .into_iter() .rev() .collect(); - - if children.is_empty() { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Room IDs in token were not found.", - )); - } - - // We have reached the room after where we last left off - let parents_len = parents.len(); - if checked!(parents_len + 1)? == short_room_ids.len() { - populate_results = true; - } } - let parents_len: u64 = parents.len().try_into()?; - if !children.is_empty() && parents_len < max_depth { - parents.push_back(current_room.clone()); - stack.push(children); + if populate { + rooms.push(summary_to_chunk(summary.clone())); + } else if queue.is_empty() && children.is_empty() { + return Err!(Request(InvalidParam("Room IDs in token were not found."))); } - // Root room in the space hierarchy, we return an error - // if this one fails. + + parents.insert(current_room.clone()); + if rooms.len() >= limit { + break; + } + + if children.is_empty() { + break; + } + + if parents.len() >= max_depth { + continue; + } + + queue.extend(children); }, - | (Some(SummaryAccessibility::Inaccessible), true) => { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "The requested room is inaccessible", - )); - }, - | (None, true) => { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "The requested room was not found", - )); - }, - // Just ignore other unavailable rooms - | (None | Some(SummaryAccessibility::Inaccessible), false) => (), } } - Ok(get_hierarchy::v1::Response { - next_batch: if let Some((room, _)) = next_room_to_traverse(&mut stack, &mut parents) { - parents.pop_front(); - parents.push_back(room); + let next_batch: OptionFuture<_> = queue + .pop_front() + .map(|(room, _)| async move { + parents.insert(room); let next_short_room_ids: Vec<_> = parents .iter() .stream() - .filter_map(|room_id| async move { - services.rooms.short.get_shortroomid(room_id).await.ok() - }) + .filter_map(|room_id| services.rooms.short.get_shortroomid(room_id).ok()) .collect() .await; - (next_short_room_ids != short_room_ids && !next_short_room_ids.is_empty()).then( - || { - PaginationToken { - short_room_ids: next_short_room_ids, - limit: UInt::new(max_depth) - .expect("When sent in request it must have been valid UInt"), - max_depth: UInt::new(max_depth) - .expect("When sent in request it must have been valid UInt"), - suggested_only, - } - .to_string() - }, - ) - } else { - None - }, - rooms: results, + (next_short_room_ids.iter().ne(short_room_ids) && !next_short_room_ids.is_empty()) + .then_some(PaginationToken { + short_room_ids: next_short_room_ids, + limit: max_depth.try_into().ok()?, + max_depth: max_depth.try_into().ok()?, + suggested_only, + }) + .as_ref() + .map(PaginationToken::to_string) + }) + .into(); + + Ok(get_hierarchy::v1::Response { + next_batch: next_batch.await.flatten(), + rooms, }) } - -fn next_room_to_traverse( - stack: &mut Vec)>>, - parents: &mut VecDeque, -) -> Option<(OwnedRoomId, Vec)> { - while stack.last().is_some_and(Vec::is_empty) { - stack.pop(); - parents.pop_back(); - } - - stack.last_mut().and_then(Vec::pop) -} diff --git a/src/api/server/hierarchy.rs b/src/api/server/hierarchy.rs index bcf2f7bc..f7bc43ab 100644 --- a/src/api/server/hierarchy.rs +++ b/src/api/server/hierarchy.rs @@ -1,10 +1,11 @@ use axum::extract::State; -use conduwuit::{Err, Result}; -use ruma::{api::federation::space::get_hierarchy, RoomId, ServerName}; -use service::{ - rooms::spaces::{get_parent_children_via, Identifier, SummaryAccessibility}, - Services, +use conduwuit::{ + utils::stream::{BroadbandExt, IterStream}, + Err, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::api::federation::space::get_hierarchy; +use service::rooms::spaces::{get_parent_children_via, Identifier, SummaryAccessibility}; use crate::Ruma; @@ -20,54 +21,51 @@ pub(crate) async fn get_hierarchy_route( return Err!(Request(NotFound("Room does not exist."))); } - get_hierarchy(&services, &body.room_id, body.origin(), body.suggested_only).await -} - -/// Gets the response for the space hierarchy over federation request -/// -/// Errors if the room does not exist, so a check if the room exists should -/// be done -async fn get_hierarchy( - services: &Services, - room_id: &RoomId, - server_name: &ServerName, - suggested_only: bool, -) -> Result { + let room_id = &body.room_id; + let suggested_only = body.suggested_only; + let ref identifier = Identifier::ServerName(body.origin()); match services .rooms .spaces - .get_summary_and_children_local(&room_id.to_owned(), Identifier::ServerName(server_name)) + .get_summary_and_children_local(room_id, identifier) .await? { - | Some(SummaryAccessibility::Accessible(room)) => { - let mut children = Vec::new(); - let mut inaccessible_children = Vec::new(); + | None => Err!(Request(NotFound("The requested room was not found"))), - for (child, _via) in get_parent_children_via(&room, suggested_only) { - match services - .rooms - .spaces - .get_summary_and_children_local(&child, Identifier::ServerName(server_name)) - .await? - { - | Some(SummaryAccessibility::Accessible(summary)) => { - children.push((*summary).into()); - }, - | Some(SummaryAccessibility::Inaccessible) => { - inaccessible_children.push(child); - }, - | None => (), - } - } - - Ok(get_hierarchy::v1::Response { - room: *room, - children, - inaccessible_children, - }) - }, | Some(SummaryAccessibility::Inaccessible) => Err!(Request(NotFound("The requested room is inaccessible"))), - | None => Err!(Request(NotFound("The requested room was not found"))), + + | Some(SummaryAccessibility::Accessible(room)) => { + let (children, inaccessible_children) = + get_parent_children_via(&room, suggested_only) + .stream() + .broad_filter_map(|(child, _via)| async move { + match services + .rooms + .spaces + .get_summary_and_children_local(&child, identifier) + .await + .ok()? + { + | None => None, + + | Some(SummaryAccessibility::Inaccessible) => + Some((None, Some(child))), + + | Some(SummaryAccessibility::Accessible(summary)) => + Some((Some(summary), None)), + } + }) + .unzip() + .map(|(children, inaccessible_children): (Vec<_>, Vec<_>)| { + ( + children.into_iter().flatten().map(Into::into).collect(), + inaccessible_children.into_iter().flatten().collect(), + ) + }) + .await; + + Ok(get_hierarchy::v1::Response { room, children, inaccessible_children }) + }, } } diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 1e2b0a9f..268d6dfe 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -1,14 +1,24 @@ mod pagination_token; +#[cfg(test)] mod tests; -use std::{collections::HashMap, sync::Arc}; +use std::sync::Arc; -use conduwuit::{debug_info, err, utils::math::usize_from_f64, Error, Result}; -use futures::StreamExt; +use conduwuit::{ + implement, + utils::{ + future::BoolExt, + math::usize_from_f64, + stream::{BroadbandExt, ReadyExt}, + IterStream, + }, + Err, Error, Result, +}; +use futures::{pin_mut, stream::FuturesUnordered, FutureExt, Stream, StreamExt, TryFutureExt}; use lru_cache::LruCache; use ruma::{ api::{ - client::{error::ErrorKind, space::SpaceHierarchyRoomsChunk}, + client::space::SpaceHierarchyRoomsChunk, federation::{ self, space::{SpaceHierarchyChildSummary, SpaceHierarchyParentSummary}, @@ -21,46 +31,46 @@ use ruma::{ }, serde::Raw, space::SpaceRoomJoinRule, - OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, + OwnedEventId, OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId, }; -use tokio::sync::Mutex; +use tokio::sync::{Mutex, MutexGuard}; pub use self::pagination_token::PaginationToken; -use crate::{rooms, sending, Dep}; - -pub struct CachedSpaceHierarchySummary { - summary: SpaceHierarchyParentSummary, -} - -pub enum SummaryAccessibility { - Accessible(Box), - Inaccessible, -} - -/// Identifier used to check if rooms are accessible -/// -/// None is used if you want to return the room, no matter if accessible or not -pub enum Identifier<'a> { - UserId(&'a UserId), - ServerName(&'a ServerName), -} +use crate::{conduwuit::utils::TryFutureExtExt, rooms, sending, Dep}; pub struct Service { services: Services, - pub roomid_spacehierarchy_cache: - Mutex>>, + pub roomid_spacehierarchy_cache: Mutex, } struct Services { state_accessor: Dep, state_cache: Dep, state: Dep, - short: Dep, event_handler: Dep, timeline: Dep, sending: Dep, } +pub struct CachedSpaceHierarchySummary { + summary: SpaceHierarchyParentSummary, +} + +#[allow(clippy::large_enum_variant)] +pub enum SummaryAccessibility { + Accessible(SpaceHierarchyParentSummary), + Inaccessible, +} + +/// Identifier used to check if rooms are accessible. None is used if you want +/// to return the room, no matter if accessible or not +pub enum Identifier<'a> { + UserId(&'a UserId), + ServerName(&'a ServerName), +} + +type Cache = LruCache>; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { let config = &args.server.config; @@ -72,7 +82,6 @@ impl crate::Service for Service { .depend::("rooms::state_accessor"), state_cache: args.depend::("rooms::state_cache"), state: args.depend::("rooms::state"), - short: args.depend::("rooms::short"), event_handler: args .depend::("rooms::event_handler"), timeline: args.depend::("rooms::timeline"), @@ -85,370 +94,407 @@ impl crate::Service for Service { fn name(&self) -> &str { crate::service::make_name(std::module_path!()) } } -impl Service { - /// Gets the summary of a space using solely local information - pub async fn get_summary_and_children_local( - &self, - current_room: &OwnedRoomId, - identifier: Identifier<'_>, - ) -> Result> { - if let Some(cached) = self - .roomid_spacehierarchy_cache - .lock() - .await - .get_mut(¤t_room.to_owned()) - .as_ref() - { - return Ok(if let Some(cached) = cached { +/// Gets the summary of a space using solely local information +#[implement(Service)] +pub async fn get_summary_and_children_local( + &self, + current_room: &RoomId, + identifier: &Identifier<'_>, +) -> Result> { + match self + .roomid_spacehierarchy_cache + .lock() + .await + .get_mut(current_room) + .as_ref() + { + | None => (), // cache miss + | Some(None) => return Ok(None), + | Some(Some(cached)) => + return Ok(Some( if self .is_accessible_child( current_room, &cached.summary.join_rule, - &identifier, + identifier, &cached.summary.allowed_room_ids, ) .await { - Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone()))) + SummaryAccessibility::Accessible(cached.summary.clone()) } else { - Some(SummaryAccessibility::Inaccessible) - } - } else { - None - }); - } + SummaryAccessibility::Inaccessible + }, + )), + }; - if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? { - let summary = self - .get_room_summary(current_room, children_pdus, &identifier) - .await; - if let Ok(summary) = summary { - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { summary: summary.clone() }), - ); + let children_pdus: Vec<_> = self + .get_stripped_space_child_events(current_room) + .collect() + .await; - Ok(Some(SummaryAccessibility::Accessible(Box::new(summary)))) - } else { - Ok(None) - } - } else { - Ok(None) - } - } + let summary = self + .get_room_summary(current_room, children_pdus, identifier) + .boxed() + .await; - /// Gets the summary of a space using solely federation - #[tracing::instrument(level = "debug", skip(self))] - async fn get_summary_and_children_federation( - &self, - current_room: &OwnedRoomId, - suggested_only: bool, - user_id: &UserId, - via: &[OwnedServerName], - ) -> Result> { - for server in via { - debug_info!("Asking {server} for /hierarchy"); - let Ok(response) = self - .services + let Ok(summary) = summary else { + return Ok(None); + }; + + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.to_owned(), + Some(CachedSpaceHierarchySummary { summary: summary.clone() }), + ); + + Ok(Some(SummaryAccessibility::Accessible(summary))) +} + +/// Gets the summary of a space using solely federation +#[implement(Service)] +#[tracing::instrument(level = "debug", skip(self))] +async fn get_summary_and_children_federation( + &self, + current_room: &RoomId, + suggested_only: bool, + user_id: &UserId, + via: &[OwnedServerName], +) -> Result> { + let request = federation::space::get_hierarchy::v1::Request { + room_id: current_room.to_owned(), + suggested_only, + }; + + let mut requests: FuturesUnordered<_> = via + .iter() + .map(|server| { + self.services .sending - .send_federation_request(server, federation::space::get_hierarchy::v1::Request { - room_id: current_room.to_owned(), - suggested_only, - }) - .await - else { - continue; - }; - - debug_info!("Got response from {server} for /hierarchy\n{response:?}"); - let summary = response.room.clone(); - - self.roomid_spacehierarchy_cache.lock().await.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { summary: summary.clone() }), - ); - - for child in response.children { - let mut guard = self.roomid_spacehierarchy_cache.lock().await; - if !guard.contains_key(current_room) { - guard.insert( - current_room.clone(), - Some(CachedSpaceHierarchySummary { - summary: { - let SpaceHierarchyChildSummary { - canonical_alias, - name, - num_joined_members, - room_id, - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - allowed_room_ids, - } = child; - - SpaceHierarchyParentSummary { - canonical_alias, - name, - num_joined_members, - room_id: room_id.clone(), - topic, - world_readable, - guest_can_join, - avatar_url, - join_rule, - room_type, - children_state: self - .get_stripped_space_child_events(&room_id) - .await? - .unwrap(), - allowed_room_ids, - } - }, - }), - ); - } - } - if self - .is_accessible_child( - current_room, - &response.room.join_rule, - &Identifier::UserId(user_id), - &response.room.allowed_room_ids, - ) - .await - { - return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone())))); - } - - return Ok(Some(SummaryAccessibility::Inaccessible)); - } + .send_federation_request(server, request.clone()) + }) + .collect(); + let Some(Ok(response)) = requests.next().await else { self.roomid_spacehierarchy_cache .lock() .await - .insert(current_room.clone(), None); + .insert(current_room.to_owned(), None); - Ok(None) - } + return Ok(None); + }; - /// Gets the summary of a space using either local or remote (federation) - /// sources - pub async fn get_summary_and_children_client( - &self, - current_room: &OwnedRoomId, - suggested_only: bool, - user_id: &UserId, - via: &[OwnedServerName], - ) -> Result> { - if let Ok(Some(response)) = self - .get_summary_and_children_local(current_room, Identifier::UserId(user_id)) - .await - { - Ok(Some(response)) - } else { - self.get_summary_and_children_federation(current_room, suggested_only, user_id, via) - .await - } - } + let summary = response.room; + self.roomid_spacehierarchy_cache.lock().await.insert( + current_room.to_owned(), + Some(CachedSpaceHierarchySummary { summary: summary.clone() }), + ); - async fn get_room_summary( - &self, - current_room: &OwnedRoomId, - children_state: Vec>, - identifier: &Identifier<'_>, - ) -> Result { - let room_id: &RoomId = current_room; - - let join_rule = self - .services - .state_accessor - .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") - .await - .map_or(JoinRule::Invite, |c: RoomJoinRulesEventContent| c.join_rule); - - let allowed_room_ids = self - .services - .state_accessor - .allowed_room_ids(join_rule.clone()); - - if !self - .is_accessible_child( - current_room, - &join_rule.clone().into(), - identifier, - &allowed_room_ids, - ) - .await - { - debug_info!("User is not allowed to see room {room_id}"); - // This error will be caught later - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "User is not allowed to see the room", - )); - } - - Ok(SpaceHierarchyParentSummary { - canonical_alias: self - .services - .state_accessor - .get_canonical_alias(room_id) - .await - .ok(), - name: self.services.state_accessor.get_name(room_id).await.ok(), - num_joined_members: self - .services - .state_cache - .room_joined_count(room_id) - .await - .unwrap_or(0) - .try_into() - .expect("user count should not be that big"), - room_id: room_id.to_owned(), - topic: self - .services - .state_accessor - .get_room_topic(room_id) - .await - .ok(), - world_readable: self - .services - .state_accessor - .is_world_readable(room_id) - .await, - guest_can_join: self.services.state_accessor.guest_can_join(room_id).await, - avatar_url: self - .services - .state_accessor - .get_avatar(room_id) - .await - .into_option() - .unwrap_or_default() - .url, - join_rule: join_rule.into(), - room_type: self - .services - .state_accessor - .get_room_type(room_id) - .await - .ok(), - children_state, - allowed_room_ids, + response + .children + .into_iter() + .stream() + .then(|child| { + self.roomid_spacehierarchy_cache + .lock() + .map(|lock| (child, lock)) }) + .ready_filter_map(|(child, mut cache)| { + (!cache.contains_key(current_room)).then_some((child, cache)) + }) + .for_each(|(child, cache)| self.cache_insert(cache, current_room, child)) + .await; + + let identifier = Identifier::UserId(user_id); + let is_accessible_child = self + .is_accessible_child( + current_room, + &summary.join_rule, + &identifier, + &summary.allowed_room_ids, + ) + .await; + + if is_accessible_child { + return Ok(Some(SummaryAccessibility::Accessible(summary))); } - /// Simply returns the stripped m.space.child events of a room - async fn get_stripped_space_child_events( - &self, - room_id: &RoomId, - ) -> Result>>, Error> { - let Ok(current_shortstatehash) = - self.services.state.get_room_shortstatehash(room_id).await - else { - return Ok(None); - }; - - let state: HashMap<_, Arc<_>> = self - .services - .state_accessor - .state_full_ids(current_shortstatehash) - .collect() - .await; - - let mut children_pdus = Vec::with_capacity(state.len()); - for (key, id) in state { - let (event_type, state_key) = - self.services.short.get_statekey_from_short(key).await?; - - if event_type != StateEventType::SpaceChild { - continue; - } - - let pdu = - self.services.timeline.get_pdu(&id).await.map_err(|e| { - err!(Database("Event {id:?} in space state not found: {e:?}")) - })?; + Ok(Some(SummaryAccessibility::Inaccessible)) +} +/// Simply returns the stripped m.space.child events of a room +#[implement(Service)] +fn get_stripped_space_child_events<'a>( + &'a self, + room_id: &'a RoomId, +) -> impl Stream> + 'a { + self.services + .state + .get_room_shortstatehash(room_id) + .map_ok(|current_shortstatehash| { + self.services + .state_accessor + .state_keys_with_ids(current_shortstatehash, &StateEventType::SpaceChild) + }) + .map(Result::into_iter) + .map(IterStream::stream) + .map(StreamExt::flatten) + .flatten_stream() + .broad_filter_map(move |(state_key, event_id): (_, OwnedEventId)| async move { + self.services + .timeline + .get_pdu(&event_id) + .map_ok(move |pdu| (state_key, pdu)) + .await + .ok() + }) + .ready_filter_map(move |(state_key, pdu)| { if let Ok(content) = pdu.get_content::() { if content.via.is_empty() { - continue; + return None; } } - if OwnedRoomId::try_from(state_key).is_ok() { - children_pdus.push(pdu.to_stripped_spacechild_state_event()); + if RoomId::parse(&state_key).is_ok() { + return Some(pdu.to_stripped_spacechild_state_event()); } - } - Ok(Some(children_pdus)) + None + }) +} + +/// Gets the summary of a space using either local or remote (federation) +/// sources +#[implement(Service)] +pub async fn get_summary_and_children_client( + &self, + current_room: &OwnedRoomId, + suggested_only: bool, + user_id: &UserId, + via: &[OwnedServerName], +) -> Result> { + let identifier = Identifier::UserId(user_id); + + if let Ok(Some(response)) = self + .get_summary_and_children_local(current_room, &identifier) + .await + { + return Ok(Some(response)); } - /// With the given identifier, checks if a room is accessable - async fn is_accessible_child( - &self, - current_room: &OwnedRoomId, - join_rule: &SpaceRoomJoinRule, - identifier: &Identifier<'_>, - allowed_room_ids: &Vec, - ) -> bool { - match identifier { - | Identifier::ServerName(server_name) => { - // Checks if ACLs allow for the server to participate - if self - .services - .event_handler - .acl_check(server_name, current_room) - .await - .is_err() - { - return false; - } - }, - | Identifier::UserId(user_id) => { - if self - .services - .state_cache - .is_joined(user_id, current_room) - .await || self - .services - .state_cache - .is_invited(user_id, current_room) - .await - { - return true; - } - }, + self.get_summary_and_children_federation(current_room, suggested_only, user_id, via) + .await +} + +#[implement(Service)] +async fn get_room_summary( + &self, + room_id: &RoomId, + children_state: Vec>, + identifier: &Identifier<'_>, +) -> Result { + let join_rule = self + .services + .state_accessor + .room_state_get_content(room_id, &StateEventType::RoomJoinRules, "") + .await + .map_or(JoinRule::Invite, |c: RoomJoinRulesEventContent| c.join_rule); + + let allowed_room_ids = self + .services + .state_accessor + .allowed_room_ids(join_rule.clone()); + + let join_rule = join_rule.clone().into(); + let is_accessible_child = self + .is_accessible_child(room_id, &join_rule, identifier, &allowed_room_ids) + .await; + + if !is_accessible_child { + return Err!(Request(Forbidden("User is not allowed to see the room",))); + } + + let name = self.services.state_accessor.get_name(room_id).ok(); + + let topic = self.services.state_accessor.get_room_topic(room_id).ok(); + + let room_type = self.services.state_accessor.get_room_type(room_id).ok(); + + let world_readable = self.services.state_accessor.is_world_readable(room_id); + + let guest_can_join = self.services.state_accessor.guest_can_join(room_id); + + let num_joined_members = self + .services + .state_cache + .room_joined_count(room_id) + .unwrap_or(0); + + let canonical_alias = self + .services + .state_accessor + .get_canonical_alias(room_id) + .ok(); + + let avatar_url = self + .services + .state_accessor + .get_avatar(room_id) + .map(|res| res.into_option().unwrap_or_default().url); + + let ( + canonical_alias, + name, + num_joined_members, + topic, + world_readable, + guest_can_join, + avatar_url, + room_type, + ) = futures::join!( + canonical_alias, + name, + num_joined_members, + topic, + world_readable, + guest_can_join, + avatar_url, + room_type + ); + + Ok(SpaceHierarchyParentSummary { + canonical_alias, + name, + topic, + world_readable, + guest_can_join, + avatar_url, + room_type, + children_state, + allowed_room_ids, + join_rule, + room_id: room_id.to_owned(), + num_joined_members: num_joined_members + .try_into() + .expect("user count should not be that big"), + }) +} + +/// With the given identifier, checks if a room is accessable +#[implement(Service)] +async fn is_accessible_child( + &self, + current_room: &RoomId, + join_rule: &SpaceRoomJoinRule, + identifier: &Identifier<'_>, + allowed_room_ids: &[OwnedRoomId], +) -> bool { + if let Identifier::ServerName(server_name) = identifier { + // Checks if ACLs allow for the server to participate + if self + .services + .event_handler + .acl_check(server_name, current_room) + .await + .is_err() + { + return false; } - match &join_rule { - | SpaceRoomJoinRule::Public - | SpaceRoomJoinRule::Knock - | SpaceRoomJoinRule::KnockRestricted => true, - | SpaceRoomJoinRule::Restricted => { - for room in allowed_room_ids { + } + + if let Identifier::UserId(user_id) = identifier { + let is_joined = self.services.state_cache.is_joined(user_id, current_room); + + let is_invited = self.services.state_cache.is_invited(user_id, current_room); + + pin_mut!(is_joined, is_invited); + if is_joined.or(is_invited).await { + return true; + } + } + + match join_rule { + | SpaceRoomJoinRule::Public + | SpaceRoomJoinRule::Knock + | SpaceRoomJoinRule::KnockRestricted => true, + | SpaceRoomJoinRule::Restricted => + allowed_room_ids + .iter() + .stream() + .any(|room| async { match identifier { - | Identifier::UserId(user) => { - if self.services.state_cache.is_joined(user, room).await { - return true; - } - }, - | Identifier::ServerName(server) => { - if self.services.state_cache.server_in_room(server, room).await { - return true; - } - }, + | Identifier::UserId(user) => + self.services.state_cache.is_joined(user, room).await, + | Identifier::ServerName(server) => + self.services.state_cache.server_in_room(server, room).await, } - } - false - }, - // Invite only, Private, or Custom join rule - | _ => false, - } + }) + .await, + + // Invite only, Private, or Custom join rule + | _ => false, } } +/// Returns the children of a SpaceHierarchyParentSummary, making use of the +/// children_state field +pub fn get_parent_children_via( + parent: &SpaceHierarchyParentSummary, + suggested_only: bool, +) -> impl DoubleEndedIterator)> + Send + '_ +{ + parent + .children_state + .iter() + .map(Raw::deserialize) + .filter_map(Result::ok) + .filter_map(move |ce| { + (!suggested_only || ce.content.suggested) + .then_some((ce.state_key, ce.content.via.into_iter())) + }) +} + +#[implement(Service)] +async fn cache_insert( + &self, + mut cache: MutexGuard<'_, Cache>, + current_room: &RoomId, + child: SpaceHierarchyChildSummary, +) { + let SpaceHierarchyChildSummary { + canonical_alias, + name, + num_joined_members, + room_id, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + allowed_room_ids, + } = child; + + let summary = SpaceHierarchyParentSummary { + canonical_alias, + name, + num_joined_members, + topic, + world_readable, + guest_can_join, + avatar_url, + join_rule, + room_type, + allowed_room_ids, + room_id: room_id.clone(), + children_state: self + .get_stripped_space_child_events(&room_id) + .collect() + .await, + }; + + cache.insert(current_room.to_owned(), Some(CachedSpaceHierarchySummary { summary })); +} + // Here because cannot implement `From` across ruma-federation-api and // ruma-client-api types impl From for SpaceHierarchyRoomsChunk { @@ -517,25 +563,3 @@ pub fn summary_to_chunk(summary: SpaceHierarchyParentSummary) -> SpaceHierarchyR children_state, } } - -/// Returns the children of a SpaceHierarchyParentSummary, making use of the -/// children_state field -#[must_use] -pub fn get_parent_children_via( - parent: &SpaceHierarchyParentSummary, - suggested_only: bool, -) -> Vec<(OwnedRoomId, Vec)> { - parent - .children_state - .iter() - .filter_map(|raw_ce| { - raw_ce.deserialize().map_or(None, |ce| { - if suggested_only && !ce.content.suggested { - None - } else { - Some((ce.state_key, ce.content.via)) - } - }) - }) - .collect() -} diff --git a/src/service/rooms/spaces/tests.rs b/src/service/rooms/spaces/tests.rs index b4c387d7..dd6c2f35 100644 --- a/src/service/rooms/spaces/tests.rs +++ b/src/service/rooms/spaces/tests.rs @@ -1,5 +1,3 @@ -#![cfg(test)] - use std::str::FromStr; use ruma::{ @@ -69,15 +67,22 @@ fn get_summary_children() { } .into(); - assert_eq!(get_parent_children_via(&summary, false), vec![ - (owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]), - (owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]), - (owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")]) - ]); - assert_eq!(get_parent_children_via(&summary, true), vec![( - owned_room_id!("!bar:example.org"), - vec![owned_server_name!("example.org")] - )]); + assert_eq!( + get_parent_children_via(&summary, false) + .map(|(k, v)| (k, v.collect::>())) + .collect::>(), + vec![ + (owned_room_id!("!foo:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")]), + (owned_room_id!("!baz:example.org"), vec![owned_server_name!("example.org")]) + ] + ); + assert_eq!( + get_parent_children_via(&summary, true) + .map(|(k, v)| (k, v.collect::>())) + .collect::>(), + vec![(owned_room_id!("!bar:example.org"), vec![owned_server_name!("example.org")])] + ); } #[test]