refactor search system

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-10-26 22:21:23 +00:00
parent f245389c02
commit 21a67513f2
4 changed files with 312 additions and 205 deletions

1
Cargo.lock generated
View file

@ -786,6 +786,7 @@ dependencies = [
name = "conduit_service"
version = "0.5.0"
dependencies = [
"arrayvec",
"async-trait",
"base64 0.22.1",
"bytes",

View file

@ -2,25 +2,32 @@ use std::collections::BTreeMap;
use axum::extract::State;
use conduit::{
debug,
utils::{IterStream, ReadyExt},
Err,
at, is_true,
result::FlatOk,
utils::{stream::ReadyExt, IterStream},
Err, PduEvent, Result,
};
use futures::{FutureExt, StreamExt};
use futures::{future::OptionFuture, FutureExt, StreamExt, TryFutureExt};
use ruma::{
api::client::{
error::ErrorKind,
search::search_events::{
self,
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
},
api::client::search::search_events::{
self,
v3::{Criteria, EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
},
events::AnyStateEvent,
serde::Raw,
uint, OwnedRoomId,
OwnedRoomId, RoomId, UInt, UserId,
};
use search_events::v3::{Request, Response};
use service::{rooms::search::RoomQuery, Services};
use crate::{Error, Result, Ruma};
use crate::Ruma;
type RoomStates = BTreeMap<OwnedRoomId, RoomState>;
type RoomState = Vec<Raw<AnyStateEvent>>;
const LIMIT_DEFAULT: usize = 10;
const LIMIT_MAX: usize = 100;
const BATCH_MAX: usize = 20;
/// # `POST /_matrix/client/r0/search`
///
@ -28,173 +35,177 @@ use crate::{Error, Result, Ruma};
///
/// - Only works if the user is currently joined to the room (TODO: Respect
/// history visibility)
pub(crate) async fn search_events_route(
State(services): State<crate::State>, body: Ruma<search_events::v3::Request>,
) -> Result<search_events::v3::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated");
pub(crate) async fn search_events_route(State(services): State<crate::State>, body: Ruma<Request>) -> Result<Response> {
let sender_user = body.sender_user();
let next_batch = body.next_batch.as_deref();
let room_events_result: OptionFuture<_> = body
.search_categories
.room_events
.as_ref()
.map(|criteria| category_room_events(&services, sender_user, next_batch, criteria))
.into();
let search_criteria = body.search_categories.room_events.as_ref().unwrap();
let filter = &search_criteria.filter;
let include_state = &search_criteria.include_state;
Ok(Response {
search_categories: ResultCategories {
room_events: room_events_result
.await
.unwrap_or_else(|| Ok(ResultRoomEvents::default()))?,
},
})
}
let room_ids = if let Some(room_ids) = &filter.rooms {
room_ids.clone()
} else {
services
.rooms
.state_cache
.rooms_joined(sender_user)
.map(ToOwned::to_owned)
.collect()
.await
};
#[allow(clippy::map_unwrap_or)]
async fn category_room_events(
services: &Services, sender_user: &UserId, next_batch: Option<&str>, criteria: &Criteria,
) -> Result<ResultRoomEvents> {
let filter = &criteria.filter;
// Use limit or else 10, with maximum 100
let limit: usize = filter
.limit
.unwrap_or_else(|| uint!(10))
.try_into()
.unwrap_or(10)
.min(100);
.map(TryInto::try_into)
.flat_ok()
.unwrap_or(LIMIT_DEFAULT)
.min(LIMIT_MAX);
let mut room_states: BTreeMap<OwnedRoomId, Vec<Raw<AnyStateEvent>>> = BTreeMap::new();
let next_batch: usize = next_batch
.map(str::parse)
.transpose()?
.unwrap_or(0)
.min(limit.saturating_mul(BATCH_MAX));
if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids {
if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err!(Request(Forbidden("You don't have permission to view this room.")));
}
// check if sender_user can see state events
if services
.rooms
.state_accessor
.user_can_see_state_events(sender_user, room_id)
.await
{
let room_state: Vec<_> = services
.rooms
.state_accessor
.room_state_full(room_id)
.await?
.values()
.map(|pdu| pdu.to_state_event())
.collect();
debug!("Room state: {:?}", room_state);
room_states.insert(room_id.clone(), room_state);
} else {
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
}
}
}
let mut search_vecs = Vec::new();
for room_id in &room_ids {
if !services
.rooms
.state_cache
.is_joined(sender_user, room_id)
.await
{
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"You don't have permission to view this room.",
));
}
if let Some(search) = services
.rooms
.search
.search_pdus(room_id, &search_criteria.search_term)
.await
{
search_vecs.push(search.0);
}
}
let mut searches: Vec<_> = search_vecs
.iter()
.map(|vec| vec.iter().peekable())
.collect();
let skip: usize = match body.next_batch.as_ref().map(|s| s.parse()) {
Some(Ok(s)) => s,
Some(Err(_)) => return Err(Error::BadRequest(ErrorKind::InvalidParam, "Invalid next_batch token.")),
None => 0, // Default to the start
};
let mut results = Vec::new();
let next_batch = skip.saturating_add(limit);
for _ in 0..next_batch {
if let Some(s) = searches
.iter_mut()
.map(|s| (s.peek().copied(), s))
.max_by_key(|(peek, _)| *peek)
.and_then(|(_, i)| i.next())
{
results.push(s);
}
}
let results: Vec<_> = results
.into_iter()
.skip(skip)
.stream()
.filter_map(|id| services.rooms.timeline.get_pdu_from_id(id).map(Result::ok))
.ready_filter(|pdu| !pdu.is_redacted())
.filter_map(|pdu| async move {
let rooms = filter
.rooms
.clone()
.map(IntoIterator::into_iter)
.map(IterStream::stream)
.map(StreamExt::boxed)
.unwrap_or_else(|| {
services
.rooms
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.state_cache
.rooms_joined(sender_user)
.map(ToOwned::to_owned)
.boxed()
});
let results: Vec<_> = rooms
.filter_map(|room_id| async move {
check_room_visible(services, sender_user, &room_id, criteria)
.await
.then_some(pdu)
.is_ok()
.then_some(room_id)
})
.take(limit)
.map(|pdu| pdu.to_room_event())
.map(|result| SearchResult {
context: EventContextResult {
end: None,
events_after: Vec::new(),
events_before: Vec::new(),
profile_info: BTreeMap::new(),
start: None,
},
rank: None,
result: Some(result),
.filter_map(|room_id| async move {
let query = RoomQuery {
room_id: &room_id,
user_id: Some(sender_user),
criteria,
skip: next_batch,
limit,
};
let (count, results) = services.rooms.search.search_pdus(&query).await.ok()?;
results
.collect::<Vec<_>>()
.map(|results| (room_id.clone(), count, results))
.map(Some)
.await
})
.collect()
.boxed()
.await;
let more_unloaded_results = searches.iter_mut().any(|s| s.peek().is_some());
let total: UInt = results
.iter()
.fold(0, |a: usize, (_, count, _)| a.saturating_add(*count))
.try_into()?;
let next_batch = more_unloaded_results.then(|| next_batch.to_string());
let state: RoomStates = results
.iter()
.stream()
.ready_filter(|_| criteria.include_state.is_some_and(is_true!()))
.filter_map(|(room_id, ..)| async move {
procure_room_state(services, room_id)
.map_ok(|state| (room_id.clone(), state))
.await
.ok()
})
.collect()
.await;
Ok(search_events::v3::Response::new(ResultCategories {
room_events: ResultRoomEvents {
count: Some(results.len().try_into().unwrap_or_else(|_| uint!(0))),
groups: BTreeMap::new(), // TODO
next_batch,
results,
state: room_states,
highlights: search_criteria
.search_term
.split_terminator(|c: char| !c.is_alphanumeric())
.map(str::to_lowercase)
.collect(),
},
}))
let results: Vec<SearchResult> = results
.into_iter()
.map(at!(2))
.flatten()
.stream()
.map(|pdu| pdu.to_room_event())
.map(|result| SearchResult {
rank: None,
result: Some(result),
context: EventContextResult {
profile_info: BTreeMap::new(), //TODO
events_after: Vec::new(), //TODO
events_before: Vec::new(), //TODO
start: None, //TODO
end: None, //TODO
},
})
.collect()
.await;
let highlights = criteria
.search_term
.split_terminator(|c: char| !c.is_alphanumeric())
.map(str::to_lowercase)
.collect();
let next_batch = (results.len() >= limit)
.then_some(next_batch.saturating_add(results.len()))
.as_ref()
.map(ToString::to_string);
Ok(ResultRoomEvents {
count: Some(total),
next_batch,
results,
state,
highlights,
groups: BTreeMap::new(), // TODO
})
}
/// Fetches the full state of `room_id` and converts each state PDU into its
/// raw state-event representation.
///
/// Any error reported by `room_state_full` is propagated to the caller.
async fn procure_room_state(services: &Services, room_id: &RoomId) -> Result<RoomState> {
    let full_state = services.rooms.state_accessor.room_state_full(room_id).await?;

    Ok(full_state
        .values()
        .map(|pdu| PduEvent::to_state_event(pdu.as_ref()))
        .collect())
}
/// Decides whether `user_id` may search `room_id` under the given criteria.
///
/// Checks are only performed when the request filter names its rooms
/// explicitly: the user must then be joined to the room and, if the request
/// also asks for room state, must be allowed to see the room's state events.
/// When the filter carries no room list nothing is looked up and the room is
/// accepted (presumably because the caller then only enumerates rooms the
/// user has joined -- verify against the caller).
///
/// Returns a `Forbidden` request error when either check fails.
async fn check_room_visible(services: &Services, user_id: &UserId, room_id: &RoomId, search: &Criteria) -> Result {
    let rooms_were_requested = search.filter.rooms.is_some();
    let state_was_requested = search.include_state.is_some_and(is_true!());

    // Membership only needs verifying for rooms the client picked itself.
    let joined = if rooms_were_requested {
        services.rooms.state_cache.is_joined(user_id, room_id).await
    } else {
        true
    };

    // State visibility is only relevant when state will be returned.
    let can_see_state = if rooms_were_requested && state_was_requested {
        services
            .rooms
            .state_accessor
            .user_can_see_state_events(user_id, room_id)
            .await
    } else {
        true
    };

    if !(joined && can_see_state) {
        return Err!(Request(Forbidden("You don't have permission to view {room_id:?}")));
    }

    Ok(())
}

View file

@ -40,6 +40,7 @@ release_max_log_level = [
]
[dependencies]
arrayvec.workspace = true
async-trait.workspace = true
base64.workspace = true
bytes.workspace = true

View file

@ -1,15 +1,23 @@
use std::sync::Arc;
use std::{iter, sync::Arc};
use arrayvec::ArrayVec;
use conduit::{
implement,
utils::{set, stream::TryIgnore, IterStream, ReadyExt},
Result,
utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt},
PduEvent, Result,
};
use database::Map;
use futures::StreamExt;
use ruma::RoomId;
use database::{keyval::Val, Map};
use futures::{Stream, StreamExt};
use ruma::{api::client::search::search_events::v3::Criteria, RoomId, UserId};
use crate::{rooms, Dep};
use crate::{
rooms,
rooms::{
short::ShortRoomId,
timeline::{PduId, RawPduId},
},
Dep,
};
pub struct Service {
db: Data,
@ -22,8 +30,24 @@ struct Data {
struct Services {
short: Dep<rooms::short::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
/// Parameters of a search restricted to a single room.
#[derive(Clone, Debug)]
pub struct RoomQuery<'a> {
/// Room to search; it is resolved to its short room id to build index keys.
pub room_id: &'a RoomId,
/// User whose event visibility is checked for every candidate event.
/// When this is `None`, `search_pdus` drops every event.
pub user_id: Option<&'a UserId>,
/// Search criteria of the request; its `search_term` is tokenized into the
/// words that are looked up.
pub criteria: &'a Criteria,
/// Maximum number of events `search_pdus` yields.
pub limit: usize,
/// Number of visible, matching events `search_pdus` skips before yielding.
pub skip: usize,
}
/// Fixed-capacity byte buffer holding a search-index key or key prefix.
type TokenId = ArrayVec<u8, TOKEN_ID_MAX_LEN>;
/// Capacity of the largest key: the short room id, a word of at most
/// `WORD_MAX_LEN` bytes, one separator byte, and a raw PDU id.
// NOTE(review): `make_tokenid` appends `PduId::LEN` bytes -- confirm that this
// equals `size_of::<RawPduId>()`.
const TOKEN_ID_MAX_LEN: usize = size_of::<ShortRoomId>() + WORD_MAX_LEN + 1 + size_of::<RawPduId>();
/// Longest word, in bytes, that `tokenize` keeps.
const WORD_MAX_LEN: usize = 50;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
@ -32,6 +56,8 @@ impl crate::Service for Service {
},
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
}))
}
@ -70,46 +96,92 @@ pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
}
#[implement(Service)]
pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)
.await
.ok()?
.to_be_bytes()
.to_vec();
pub async fn search_pdus<'a>(
&'a self, query: &'a RoomQuery<'a>,
) -> Result<(usize, impl Stream<Item = PduEvent> + Send + 'a)> {
let pdu_ids: Vec<_> = self.search_pdu_ids(query).await?.collect().await;
let words: Vec<_> = tokenize(search_string).collect();
let bufs: Vec<_> = words
.clone()
let count = pdu_ids.len();
let pdus = pdu_ids
.into_iter()
.stream()
.then(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.db.tokenids
.rev_raw_keys_from(&last_possible_id) // Newest pdus first
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix2))
.map(move |key| key[prefix3.len()..].to_vec())
.collect::<Vec<_>>()
.filter_map(move |result_pdu_id: RawPduId| async move {
self.services
.timeline
.get_pdu_from_id(&result_pdu_id)
.await
.ok()
})
.collect()
.await;
.ready_filter(|pdu| !pdu.is_redacted())
.filter_map(move |pdu| async move {
self.services
.state_accessor
.user_can_see_event(query.user_id?, &pdu.room_id, &pdu.event_id)
.await
.then_some(pdu)
})
.skip(query.skip)
.take(query.limit);
let bufs = bufs.iter().map(|buf| buf.iter());
Ok((count, pdus))
}
let results = set::intersection(bufs).cloned().collect();
// result is modeled as a stream such that callers don't have to be refactored
// though an additional async/wrap still exists for now
#[implement(Service)]
pub async fn search_pdu_ids(&self, query: &RoomQuery<'_>) -> Result<impl Stream<Item = RawPduId> + Send + '_> {
let shortroomid = self.services.short.get_shortroomid(query.room_id).await?;
Some((results, words))
let pdu_ids = self.search_pdu_ids_query_room(query, shortroomid).await;
let iters = pdu_ids.into_iter().map(IntoIterator::into_iter);
Ok(set::intersection(iters).stream())
}
/// Looks up every word of the query's search term in one room.
///
/// Returns one list of PDU ids per word, in the order the words appear in the
/// search term. Words are queried one after another.
#[implement(Service)]
async fn search_pdu_ids_query_room(&self, query: &RoomQuery<'_>, shortroomid: ShortRoomId) -> Vec<Vec<RawPduId>> {
    let mut ids_per_word = Vec::new();

    for word in tokenize(&query.criteria.search_term) {
        let ids: Vec<RawPduId> = self
            .search_pdu_ids_query_words(shortroomid, &word)
            .collect()
            .await;

        ids_per_word.push(ids);
    }

    ids_per_word
}
/// Streams the ids of the PDUs in a room that contain `word`.
///
/// Each index key consists of the room/word prefix followed by the PDU id;
/// the prefix is stripped and the next `PduId::LEN` bytes are converted into
/// a `RawPduId`. Keys whose remainder is too short or fails to convert are
/// skipped.
#[implement(Service)]
fn search_pdu_ids_query_words<'a>(
    &'a self, shortroomid: ShortRoomId, word: &'a str,
) -> impl Stream<Item = RawPduId> + Send + '_ {
    // Identical for every key of this word, so compute it once.
    let id_offset = prefix_len(word);

    self.search_pdu_ids_query_word(shortroomid, word)
        .ready_filter_map(move |key| {
            let id_bytes = key[id_offset..].chunks_exact(PduId::LEN).next()?;

            RawPduId::try_from(id_bytes).ok()
        })
}
/// Streams the raw index keys stored for `word` in one room.
///
/// The scan starts at the largest key this room/word pair could have (the
/// prefix followed by a PDU id of all `0xFF` bytes) and ends at the first key
/// that no longer carries the room/word prefix. Database errors are ignored.
#[implement(Service)]
fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ {
    // Upper bound for any PDU id: every byte at its maximum.
    let max_pdu_id = [u8::MAX; PduId::LEN];

    // Newest pdus first
    let upper_bound = make_tokenid(shortroomid, word, &max_pdu_id);
    let word_prefix = make_prefix(shortroomid, word);

    self.db
        .tokenids
        .rev_raw_keys_from(&upper_bound)
        .ignore_err()
        .ready_take_while(move |key| key.starts_with(&word_prefix))
}
/// Splits a string into tokens used as keys in the search inverted index
@ -119,6 +191,28 @@ pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option
/// Yields the lower-cased words of `body`.
///
/// The text is split at every non-alphanumeric character. Empty fragments
/// (produced by adjacent separators) and words longer than `WORD_MAX_LEN`
/// bytes are dropped.
fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
    body.split_terminator(|c: char| !c.is_alphanumeric())
        // The length limit is checked once, against the shared constant only,
        // so it cannot drift from the key-capacity calculation.
        .filter(|word| !word.is_empty() && word.len() <= WORD_MAX_LEN)
        .map(str::to_lowercase)
}
/// Builds a full index key: the room/word prefix followed by `pdu_id`.
///
/// `pdu_id` is expected to be exactly `PduId::LEN` bytes long; this is only
/// checked in debug builds.
fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId {
    debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch");

    let mut token_id = make_prefix(shortroomid, word);
    token_id.extend_from_slice(pdu_id);

    token_id
}
/// Builds the key prefix shared by all index entries of `word` in one room:
/// the big-endian short room id, the bytes of the word, and one separator
/// byte.
fn make_prefix(shortroomid: ShortRoomId, word: &str) -> TokenId {
    let room_bytes = shortroomid.to_be_bytes();

    let mut prefix = TokenId::new();
    for part in [room_bytes.as_slice(), word.as_bytes()] {
        prefix.extend_from_slice(part);
    }
    prefix.push(database::SEP);

    prefix
}
/// Length in bytes of the key prefix for `word`: the short room id, the word
/// itself, and one separator byte. The sum saturates instead of overflowing.
fn prefix_len(word: &str) -> usize {
    // Room id plus separator is the fixed part of every prefix.
    let fixed_len = size_of::<ShortRoomId>().saturating_add(1);

    fixed_len.saturating_add(word.len())
}