From 21a67513f2480e6cb1cb0322e15016ba8d919dac Mon Sep 17 00:00:00 2001 From: Jason Volk Date: Sat, 26 Oct 2024 22:21:23 +0000 Subject: [PATCH] refactor search system Signed-off-by: Jason Volk --- Cargo.lock | 1 + src/api/client/search.rs | 339 +++++++++++++++++--------------- src/service/Cargo.toml | 1 + src/service/rooms/search/mod.rs | 176 +++++++++++++---- 4 files changed, 312 insertions(+), 205 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index c64d3cc6..a8acce7d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -786,6 +786,7 @@ dependencies = [ name = "conduit_service" version = "0.5.0" dependencies = [ + "arrayvec", "async-trait", "base64 0.22.1", "bytes", diff --git a/src/api/client/search.rs b/src/api/client/search.rs index b073640e..1e5384fe 100644 --- a/src/api/client/search.rs +++ b/src/api/client/search.rs @@ -2,25 +2,32 @@ use std::collections::BTreeMap; use axum::extract::State; use conduit::{ - debug, - utils::{IterStream, ReadyExt}, - Err, + at, is_true, + result::FlatOk, + utils::{stream::ReadyExt, IterStream}, + Err, PduEvent, Result, }; -use futures::{FutureExt, StreamExt}; +use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt}; use ruma::{ - api::client::{ - error::ErrorKind, - search::search_events::{ - self, - v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, - }, + api::client::search::search_events::{ + self, + v3::{Criteria, EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, }, events::AnyStateEvent, serde::Raw, - uint, OwnedRoomId, + OwnedRoomId, RoomId, UInt, UserId, }; +use search_events::v3::{Request, Response}; +use service::{rooms::search::RoomQuery, Services}; -use crate::{Error, Result, Ruma}; +use crate::Ruma; + +type RoomStates = BTreeMap; +type RoomState = Vec>; + +const LIMIT_DEFAULT: usize = 10; +const LIMIT_MAX: usize = 100; +const BATCH_MAX: usize = 20; /// # `POST /_matrix/client/r0/search` /// @@ -28,173 +35,177 @@ use crate::{Error, Result, Ruma}; /// /// - Only works if the user is currently joined to the room (TODO: Respect /// history visibility) -pub(crate) async fn search_events_route( - State(services): State, body: Ruma, -) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); +pub(crate) async fn search_events_route(State(services): State, body: Ruma) -> Result { + let sender_user = body.sender_user(); + let next_batch = body.next_batch.as_deref(); + let room_events_result: OptionFuture<_> = body + .search_categories + .room_events + .as_ref() + .map(|criteria| category_room_events(&services, sender_user, next_batch, criteria)) + .into(); - let search_criteria = body.search_categories.room_events.as_ref().unwrap(); - let filter = &search_criteria.filter; - let include_state = &search_criteria.include_state; + Ok(Response { + search_categories: ResultCategories { + room_events: room_events_result + .await + .unwrap_or_else(|| Ok(ResultRoomEvents::default()))?, + }, + }) +} - let room_ids = if let Some(room_ids) = &filter.rooms { - room_ids.clone() - } else { - services - .rooms - .state_cache - .rooms_joined(sender_user) - .map(ToOwned::to_owned) - .collect() - .await - }; +#[allow(clippy::map_unwrap_or)] +async fn category_room_events( + services: &Services, sender_user: &UserId, next_batch: Option<&str>, criteria: &Criteria, +) -> Result { + let filter = &criteria.filter; - // Use limit or else 10, with maximum 100 let limit: usize = filter .limit - .unwrap_or_else(|| uint!(10)) - .try_into() - .unwrap_or(10) - .min(100); + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(LIMIT_DEFAULT) + .min(LIMIT_MAX); - let mut room_states: BTreeMap>> = BTreeMap::new(); + let next_batch: usize = next_batch + .map(str::parse) + .transpose()? + .unwrap_or(0) + .min(limit.saturating_mul(BATCH_MAX)); - if include_state.is_some_and(|include_state| include_state) { - for room_id in &room_ids { - if !services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - return Err!(Request(Forbidden("You don't have permission to view this room."))); - } - - // check if sender_user can see state events - if services - .rooms - .state_accessor - .user_can_see_state_events(sender_user, room_id) - .await - { - let room_state: Vec<_> = services - .rooms - .state_accessor - .room_state_full(room_id) - .await? - .values() - .map(|pdu| pdu.to_state_event()) - .collect(); - - debug!("Room state: {:?}", room_state); - - room_states.insert(room_id.clone(), room_state); - } else { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - } - } - - let mut search_vecs = Vec::new(); - - for room_id in &room_ids { - if !services - .rooms - .state_cache - .is_joined(sender_user, room_id) - .await - { - return Err(Error::BadRequest( - ErrorKind::forbidden(), - "You don't have permission to view this room.", - )); - } - - if let Some(search) = services - .rooms - .search - .search_pdus(room_id, &search_criteria.search_term) - .await - { - search_vecs.push(search.0); - } - } - - let mut searches: Vec<_> = search_vecs - .iter() - .map(|vec| vec.iter().peekable()) - .collect(); - - let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) { - Some(Ok(s)) => s, - Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")), - None => 0, // Default to the start - }; - - let mut results = Vec::new(); - let next_batch = skip.saturating_add(limit); - - for _ in 0..next_batch { - if let Some(s) = searches - .iter_mut() - .map(|s| (s.peek().copied(), s)) - .max_by_key(|(peek, _)| *peek) - .and_then(|(_, i)| i.next()) - { - results.push(s); - } - } - - let results: Vec<_> = results - .into_iter() - .skip(skip) - .stream() - .filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok)) - .ready_filter(|pdu| !pdu.is_redacted()) - .filter_map(|pdu| async move { + let rooms = filter + .rooms + .clone() + .map(IntoIterator::into_iter) + .map(IterStream::stream) + .map(StreamExt::boxed) + .unwrap_or_else(|| { services .rooms - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) + .state_cache + .rooms_joined(sender_user) + .map(ToOwned::to_owned) + .boxed() + }); + + let results: Vec<_> = rooms + .filter_map(|room_id| async move { + check_room_visible(services, sender_user, &room_id, criteria) .await - .then_some(pdu) + .is_ok() + .then_some(room_id) }) - .take(limit) - .map(|pdu| pdu.to_room_event()) - .map(|result| SearchResult { - context: EventContextResult { - end: None, - events_after: Vec::new(), - events_before: Vec::new(), - profile_info: BTreeMap::new(), - start: None, - }, - rank: None, - result: Some(result), + .filter_map(|room_id| async move { + let query = RoomQuery { + room_id: &room_id, + user_id: Some(sender_user), + criteria, + skip: next_batch, + limit, + }; + + let (count, results) = services.rooms.search.search_pdus(&query).await.ok()?; + + results + .collect::>() + .map(|results| (room_id.clone(), count, results)) + .map(Some) + .await }) .collect() - .boxed() .await; - let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some()); + let total: UInt = results + .iter() + .fold(0, |a: usize, (_, count, _)| a.saturating_add(*count)) + .try_into()?; - let next_batch = more_unloaded_results.then(|| next_batch.to_string()); + let state: RoomStates = results + .iter() + .stream() + .ready_filter(|_| criteria.include_state.is_some_and(is_true!())) + .filter_map(|(room_id, ..)| async move { + procure_room_state(services, room_id) + .map_ok(|state| (room_id.clone(), state)) + .await + .ok() + }) + .collect() + .await; - Ok(search_events::v3::Response::new(ResultCategories { - room_events: ResultRoomEvents { - count: Some(results.len().try_into().unwrap_or_else(|_| uint!(0))), - groups: BTreeMap::new(), // TODO - next_batch, - results, - state: room_states, - highlights: search_criteria - .search_term - .split_terminator(|c: char| !c.is_alphanumeric()) - .map(str::to_lowercase) - .collect(), - }, - })) + let results: Vec = results + .into_iter() + .map(at!(2)) + .flatten() + .stream() + .map(|pdu| pdu.to_room_event()) + .map(|result| SearchResult { + rank: None, + result: Some(result), + context: EventContextResult { + profile_info: BTreeMap::new(), //TODO + events_after: Vec::new(), //TODO + events_before: Vec::new(), //TODO + start: None, //TODO + end: None, //TODO + }, + }) + .collect() + .await; + + let highlights = criteria + .search_term + .split_terminator(|c: char| !c.is_alphanumeric()) + .map(str::to_lowercase) + .collect(); + + let next_batch = (results.len() >= limit) + .then_some(next_batch.saturating_add(results.len())) + .as_ref() + .map(ToString::to_string); + + Ok(ResultRoomEvents { + count: Some(total), + next_batch, + results, + state, + highlights, + groups: BTreeMap::new(), // TODO + }) +} + +async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result { + let state_map = services + .rooms + .state_accessor + .room_state_full(room_id) + .await?; + + let state_events = state_map + .values() + .map(AsRef::as_ref) + .map(PduEvent::to_state_event) + .collect(); + + Ok(state_events) +} + +async fn check_room_visible(services: &Services, user_id: &UserId, room_id: &RoomId, search: &Criteria) -> Result { + let check_visible = search.filter.rooms.is_some(); + let check_state = check_visible && search.include_state.is_some_and(is_true!()); + + let is_joined = !check_visible || services.rooms.state_cache.is_joined(user_id, room_id).await; + + let state_visible = !check_state + || services + .rooms + .state_accessor + .user_can_see_state_events(user_id, room_id) + .await; + + if !is_joined || !state_visible { + return Err!(Request(Forbidden("You don't have permission to view {room_id:?}"))); + } + + Ok(()) } diff --git a/src/service/Cargo.toml b/src/service/Cargo.toml index 737a7039..7578ef64 100644 --- a/src/service/Cargo.toml +++ b/src/service/Cargo.toml @@ -40,6 +40,7 @@ release_max_log_level = [ ] [dependencies] +arrayvec.workspace = true async-trait.workspace = true base64.workspace = true bytes.workspace = true diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 032ad55c..8882ec99 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,15 +1,23 @@ -use std::sync::Arc; +use std::{iter, sync::Arc}; +use arrayvec::ArrayVec; use conduit::{ implement, - utils::{set, stream::TryIgnore, IterStream, ReadyExt}, - Result, + utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt}, + PduEvent, Result, }; -use database::Map; -use futures::StreamExt; -use ruma::RoomId; +use database::{keyval::Val, Map}; +use futures::{Stream, StreamExt}; +use ruma::{api::client::search::search_events::v3::Criteria, RoomId, UserId}; -use crate::{rooms, Dep}; +use crate::{ + rooms, + rooms::{ + short::ShortRoomId, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub struct Service { db: Data, @@ -22,8 +30,24 @@ struct Data { struct Services { short: Dep, + state_accessor: Dep, + timeline: Dep, } +#[derive(Clone, Debug)] +pub struct RoomQuery<'a> { + pub room_id: &'a RoomId, + pub user_id: Option<&'a UserId>, + pub criteria: &'a Criteria, + pub limit: usize, + pub skip: usize, +} + +type TokenId = ArrayVec; + +const TOKEN_ID_MAX_LEN: usize = size_of::() + WORD_MAX_LEN + 1 + size_of::(); +const WORD_MAX_LEN: usize = 50; + impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { @@ -32,6 +56,8 @@ impl crate::Service for Service { }, services: Services { short: args.depend::("rooms::short"), + state_accessor: args.depend::("rooms::state_accessor"), + timeline: args.depend::("rooms::timeline"), }, })) } @@ -70,46 +96,92 @@ pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { } #[implement(Service)] -pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec>, Vec)> { - let prefix = self - .services - .short - .get_shortroomid(room_id) - .await - .ok()? - .to_be_bytes() - .to_vec(); +pub async fn search_pdus<'a>( + &'a self, query: &'a RoomQuery<'a>, +) -> Result<(usize, impl Stream + Send + 'a)> { + let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await; - let words: Vec<_> = tokenize(search_string).collect(); - - let bufs: Vec<_> = words - .clone() + let count = pdu_ids.len(); + let pdus = pdu_ids .into_iter() .stream() - .then(move |word| { - let mut prefix2 = prefix.clone(); - prefix2.extend_from_slice(word.as_bytes()); - prefix2.push(0xFF); - let prefix3 = prefix2.clone(); - - let mut last_possible_id = prefix2.clone(); - last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes()); - - self.db.tokenids - .rev_raw_keys_from(&last_possible_id) // Newest pdus first - .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix2)) - .map(move |key| key[prefix3.len()..].to_vec()) - .collect::>() + .filter_map(move |result_pdu_id: RawPduId| async move { + self.services + .timeline + .get_pdu_from_id(&result_pdu_id) + .await + .ok() }) - .collect() - .await; + .ready_filter(|pdu| !pdu.is_redacted()) + .filter_map(move |pdu| async move { + self.services + .state_accessor + .user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id) + .await + .then_some(pdu) + }) + .skip(query.skip) + .take(query.limit); - let bufs = bufs.iter().map(|buf| buf.iter()); + Ok((count, pdus)) +} - let results = set::intersection(bufs).cloned().collect(); +// result is modeled as a stream such that callers don't have to be refactored +// though an additional async/wrap still exists for now +#[implement(Service)] +pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result + Send + '_> { + let shortroomid = self.services.short.get_shortroomid(query.room_id).await?; - Some((results, words)) + let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await; + + let iters = pdu_ids.into_iter().map(IntoIterator::into_iter); + + Ok(set::intersection(iters).stream()) +} + +#[implement(Service)] +async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec> { + tokenize(&query.criteria.search_term) + .stream() + .then(|word| async move { + self.search_pdu_ids_query_words(shortroomid, &word) + .collect::>() + .await + }) + .collect::>() + .await +} + +/// Iterate over PduId's containing a word +#[implement(Service)] +fn search_pdu_ids_query_words<'a>( + &'a self, shortroomid: ShortRoomId, word: &'a str, +) -> impl Stream + Send + '_ { + self.search_pdu_ids_query_word(shortroomid, word) + .ready_filter_map(move |key| { + key[prefix_len(word)..] + .chunks_exact(PduId::LEN) + .next() + .map(RawPduId::try_from) + .and_then(Result::ok) + }) +} + +/// Iterate over raw database results for a word +#[implement(Service)] +fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream> + Send + '_ { + const PDUID_LEN: usize = PduId::LEN; + // rustc says const'ing this not yet stable + let end_id: ArrayVec = iter::repeat(u8::MAX).take(PduId::LEN).collect(); + + // Newest pdus first + let end = make_tokenid(shortroomid, word, end_id.as_slice()); + let prefix = make_prefix(shortroomid, word); + self.db + .tokenids + .rev_raw_keys_from(&end) + .ignore_err() + .ready_take_while(move |key| key.starts_with(&prefix)) } /// Splits a string into tokens used as keys in the search inverted index @@ -119,6 +191,28 @@ pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option fn tokenize(body: &str) -> impl Iterator + Send + '_ { body.split_terminator(|c: char| !c.is_alphanumeric()) .filter(|s| !s.is_empty()) - .filter(|word| word.len() <= 50) + .filter(|word| word.len() <= WORD_MAX_LEN) .map(str::to_lowercase) } + +fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId { + debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch"); + + let mut key = make_prefix(shortroomid, word); + key.extend_from_slice(pdu_id); + key +} + +fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId { + let mut key = TokenId::new(); + key.extend_from_slice(&shortroomid.to_be_bytes()); + key.extend_from_slice(word.as_bytes()); + key.push(database::SEP); + key +} + +fn prefix_len(word: &str) -> usize { + size_of::() + .saturating_add(word.len()) + .saturating_add(1) +}