diff --git a/src/api/client/context.rs b/src/api/client/context.rs index af4e26f0..acd7d80b 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -2,7 +2,7 @@ use std::iter::once; use axum::extract::State; use conduit::{ - at, err, + at, err, ref_at, utils::{ future::TryExtExt, stream::{BroadbandExt, ReadyExt, WidebandExt}, @@ -10,7 +10,7 @@ use conduit::{ }, Err, Result, }; -use futures::{future::try_join, StreamExt, TryFutureExt}; +use futures::{join, try_join, FutureExt, StreamExt, TryFutureExt}; use ruma::{ api::client::{context::get_context, filter::LazyLoadOptions}, events::StateEventType, @@ -37,6 +37,7 @@ pub(crate) async fn get_context_route( let filter = &body.filter; let sender = body.sender(); let (sender_user, _) = sender; + let room_id = &body.room_id; // Use limit or else 10, with maximum 100 let limit: usize = body @@ -70,42 +71,63 @@ pub(crate) async fn get_context_route( .get_pdu(&body.event_id) .map_err(|_| err!(Request(NotFound("Base event not found.")))); - let (base_token, base_event) = try_join(base_token, base_event).await?; - - let room_id = &base_event.room_id; - - if !services + let visible = services .rooms .state_accessor - .user_can_see_event(sender_user, room_id, &body.event_id) - .await - { + .user_can_see_event(sender_user, &body.room_id, &body.event_id) + .map(Ok); + + let (base_token, base_event, visible) = try_join!(base_token, base_event, visible)?; + + if base_event.room_id != body.room_id { + return Err!(Request(NotFound("Base event not found."))); + } + + if !visible { return Err!(Request(Forbidden("You don't have permission to view this event."))); } - let events_before: Vec<_> = services + let events_before = services .rooms .timeline - .pdus_rev(Some(sender_user), room_id, Some(base_token)) - .await? - .ready_filter_map(|item| event_filter(item, filter)) - .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) - .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) - .take(limit / 2) - .collect() - .await; + .pdus_rev(Some(sender_user), room_id, Some(base_token)); - let events_after: Vec<_> = services + let events_after = services .rooms .timeline - .pdus(Some(sender_user), room_id, Some(base_token)) - .await? + .pdus(Some(sender_user), room_id, Some(base_token)); + + let (events_before, events_after) = try_join!(events_before, events_after)?; + + let events_before = events_before .ready_filter_map(|item| event_filter(item, filter)) .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) .take(limit / 2) - .collect() - .await; + .collect(); + + let events_after = events_after + .ready_filter_map(|item| event_filter(item, filter)) + .wide_filter_map(|item| ignored_filter(&services, item, sender_user)) + .wide_filter_map(|item| visibility_filter(&services, item, sender_user)) + .take(limit / 2) + .collect(); + + let (events_before, events_after): (Vec<_>, Vec<_>) = join!(events_before, events_after); + + let state_at = events_after + .last() + .map(ref_at!(1)) + .map_or(body.event_id.as_ref(), |e| e.event_id.as_ref()); + + let state_ids = services + .rooms + .state_accessor + .pdu_shortstatehash(state_at) + .or_else(|_| services.rooms.state.get_room_shortstatehash(room_id)) + .and_then(|shortstatehash| services.rooms.state_accessor.state_full_ids(shortstatehash)) + .map_err(|e| err!(Database("State not found: {e}"))) + .await?; let lazy = once(&(base_token, base_event.clone())) .chain(events_before.iter()) @@ -116,48 +138,31 @@ pub(crate) async fn get_context_route( }) .await; - let state_id = events_after - .last() - .map_or(body.event_id.as_ref(), |(_, e)| e.event_id.as_ref()); - - let shortstatehash = services - .rooms - .state_accessor - .pdu_shortstatehash(state_id) - .or_else(|_| services.rooms.state.get_room_shortstatehash(room_id)) - .await - .map_err(|e| err!(Database("State hash not found: {e}")))?; - - let state_ids = services - .rooms - .state_accessor - .state_full_ids(shortstatehash) - .await - .map_err(|e| err!(Database("State not found: {e}")))?; - let lazy = &lazy; let state: Vec<_> = state_ids - .into_iter() + .iter() .stream() .broad_filter_map(|(shortstatekey, event_id)| { services .rooms .short - .get_statekey_from_short(shortstatekey) + .get_statekey_from_short(*shortstatekey) .map_ok(move |(event_type, state_key)| (event_type, state_key, event_id)) .ok() }) .ready_filter_map(|(event_type, state_key, event_id)| { - if lazy_load_enabled && event_type == StateEventType::RoomMember { - let user_id: &UserId = state_key.as_str().try_into().ok()?; - if !lazy.contains(user_id) { - return None; - } + if !lazy_load_enabled || event_type != StateEventType::RoomMember { + return Some(event_id); } - Some(event_id) + state_key + .as_str() + .try_into() + .ok() + .filter(|&user_id: &&UserId| lazy.contains(user_id)) + .map(|_| event_id) }) - .broad_filter_map(|event_id: OwnedEventId| async move { services.rooms.timeline.get_pdu(&event_id).await.ok() }) + .broad_filter_map(|event_id: &OwnedEventId| services.rooms.timeline.get_pdu(event_id).ok()) .map(|pdu| pdu.to_state_event()) .collect() .await;