implement include_state search criteria

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
strawberry 2024-03-24 12:53:32 -04:00 committed by June
parent c2e89b939c
commit fbefbd57be
3 changed files with 72 additions and 31 deletions

42
Cargo.lock generated
View file

@ -96,13 +96,13 @@ dependencies = [
[[package]] [[package]]
name = "async-trait" name = "async-trait"
version = "0.1.78" version = "0.1.79"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "461abc97219de0eaaf81fe3ef974a540158f3d079c2ab200f891f1a2ef201e85" checksum = "a507401cad91ec6a857ed5513a2073c82a9b9048762b885bb98655b306964681"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -256,7 +256,7 @@ dependencies = [
"regex", "regex",
"rustc-hash", "rustc-hash",
"shlex", "shlex",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -404,7 +404,7 @@ dependencies = [
"heck 0.5.0", "heck 0.5.0",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -577,7 +577,7 @@ checksum = "f46882e17999c6cc590af592290432be3bce0428cb0d5f8b6715e4dc7b383eb3"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -674,7 +674,7 @@ dependencies = [
"heck 0.4.1", "heck 0.4.1",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -799,7 +799,7 @@ checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -1810,7 +1810,7 @@ dependencies = [
"proc-macro2", "proc-macro2",
"proc-macro2-diagnostics", "proc-macro2-diagnostics",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -1884,7 +1884,7 @@ checksum = "2f38a4412a78282e09a2cf38d195ea5420d15ba0602cb375210efbc877243965"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -1979,7 +1979,7 @@ checksum = "af066a9c399a26e020ada66a034357a868728e72cd426f3adcd35f80d88d88c8"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
"version_check", "version_check",
"yansi", "yansi",
] ]
@ -2294,7 +2294,7 @@ dependencies = [
"quote", "quote",
"ruma-identifiers-validation", "ruma-identifiers-validation",
"serde", "serde",
"syn 2.0.53", "syn 2.0.55",
"toml", "toml",
] ]
@ -2532,7 +2532,7 @@ checksum = "7eb0b34b42edc17f6b7cac84a52a1c5f0e1bb2227e997ca9011ea3dd34e8610b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -2798,9 +2798,9 @@ dependencies = [
[[package]] [[package]]
name = "syn" name = "syn"
version = "2.0.53" version = "2.0.55"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7383cd0e49fff4b6b90ca5670bfd3e9d6a733b3f90c686605aa7eec8c4996032" checksum = "002a1b3dbf967edfafc32655d0f377ab0bb7b994aa1d32c8cc7e9b8bf3ebb8f0"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
@ -2862,7 +2862,7 @@ checksum = "c61f3ba182994efc43764a46c018c347bc492c79f024e705f46567b418f6d4f7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -3000,7 +3000,7 @@ checksum = "5b8a1e28f2deaa14e508979454cb3a223b10b938b45af148bc0986de36f1923b"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -3154,7 +3154,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]
@ -3380,7 +3380,7 @@ dependencies = [
"once_cell", "once_cell",
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3414,7 +3414,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
"wasm-bindgen-backend", "wasm-bindgen-backend",
"wasm-bindgen-shared", "wasm-bindgen-shared",
] ]
@ -3692,7 +3692,7 @@ checksum = "9ce1b18ccd8e73a9321186f97e46f9f04b778851177567b1975109d26a08d2a6"
dependencies = [ dependencies = [
"proc-macro2", "proc-macro2",
"quote", "quote",
"syn 2.0.53", "syn 2.0.55",
] ]
[[package]] [[package]]

View file

@ -51,7 +51,7 @@ serde_html_form = "0.2.5"
hmac = "0.12.1" hmac = "0.12.1"
sha-1 = "0.10.1" sha-1 = "0.10.1"
async-trait = "0.1.78" async-trait = "0.1.79"
# used for checking if an IP is in specific subnets / CIDR ranges easier # used for checking if an IP is in specific subnets / CIDR ranges easier
ipaddress = "0.1.3" ipaddress = "0.1.3"

View file

@ -1,12 +1,18 @@
use std::collections::BTreeMap; use std::collections::BTreeMap;
use ruma::api::client::{ use ruma::{
error::ErrorKind, api::client::{
search::search_events::{ error::ErrorKind,
self, search::search_events::{
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult}, self,
v3::{EventContextResult, ResultCategories, ResultRoomEvents, SearchResult},
},
}, },
events::AnyStateEvent,
serde::Raw,
OwnedRoomId,
}; };
use tracing::debug;
use crate::{services, Error, Result, Ruma}; use crate::{services, Error, Result, Ruma};
@ -21,6 +27,7 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
let search_criteria = body.search_categories.room_events.as_ref().unwrap(); let search_criteria = body.search_categories.room_events.as_ref().unwrap();
let filter = &search_criteria.filter; let filter = &search_criteria.filter;
let include_state = &search_criteria.include_state;
let room_ids = filter let room_ids = filter
.rooms .rooms
@ -30,17 +37,51 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
// Use limit or else 10, with maximum 100 // Use limit or else 10, with maximum 100
let limit = filter.limit.map_or(10, u64::from).min(100) as usize; let limit = filter.limit.map_or(10, u64::from).min(100) as usize;
let mut room_states: BTreeMap<OwnedRoomId, Vec<Raw<AnyStateEvent>>> = BTreeMap::new();
if include_state.is_some_and(|include_state| include_state) {
for room_id in &room_ids {
if !services().rooms.state_cache.is_joined(sender_user, room_id)? {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You don't have permission to view this room.",
));
}
// check if sender_user can see state events
if services().rooms.state_accessor.user_can_see_state_events(sender_user, room_id)? {
let room_state = services()
.rooms
.state_accessor
.room_state_full(room_id)
.await?
.values()
.map(|pdu| pdu.to_state_event())
.collect::<Vec<_>>();
debug!("Room state: {:?}", room_state);
room_states.insert(room_id.clone(), room_state);
} else {
return Err(Error::BadRequest(
ErrorKind::Forbidden,
"You don't have permission to view this room.",
));
}
}
}
let mut searches = Vec::new(); let mut searches = Vec::new();
for room_id in room_ids { for room_id in &room_ids {
if !services().rooms.state_cache.is_joined(sender_user, &room_id)? { if !services().rooms.state_cache.is_joined(sender_user, room_id)? {
return Err(Error::BadRequest( return Err(Error::BadRequest(
ErrorKind::Forbidden, ErrorKind::Forbidden,
"You don't have permission to view this room.", "You don't have permission to view this room.",
)); ));
} }
if let Some(search) = services().rooms.search.search_pdus(&room_id, &search_criteria.search_term)? { if let Some(search) = services().rooms.search.search_pdus(room_id, &search_criteria.search_term)? {
searches.push(search.0.peekable()); searches.push(search.0.peekable());
} }
} }
@ -114,7 +155,7 @@ pub async fn search_events_route(body: Ruma<search_events::v3::Request>) -> Resu
groups: BTreeMap::new(), // TODO groups: BTreeMap::new(), // TODO
next_batch, next_batch,
results, results,
state: BTreeMap::new(), // TODO state: room_states,
highlights: search_criteria highlights: search_criteria
.search_term .search_term
.split_terminator(|c: char| !c.is_alphanumeric()) .split_terminator(|c: char| !c.is_alphanumeric())