feat: recurse relationships (and fix some lints)

from https://gitlab.com/famedly/conduit/-/merge_requests/613

Signed-off-by: strawberry <strawberry@puppygock.gay>
This commit is contained in:
Matthias Ahouansou 2024-04-03 14:10:00 -04:00 committed by June
parent 661dba688a
commit ed960f41ac
2 changed files with 101 additions and 99 deletions

View file

@ -2,7 +2,7 @@ use ruma::api::client::relations::{
get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type,
}; };
use crate::{service::rooms::timeline::PduCount, services, Result, Ruma}; use crate::{services, Result, Ruma};
/// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}`
pub async fn get_relating_events_with_rel_type_and_event_type_route( pub async fn get_relating_events_with_rel_type_and_event_type_route(
@ -10,26 +10,6 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type_and_event_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -39,17 +19,18 @@ pub async fn get_relating_events_with_rel_type_and_event_type_route(
&body.event_id, &body.event_id,
&Some(body.event_type.clone()), &Some(body.event_type.clone()),
&Some(body.rel_type.clone()), &Some(body.rel_type.clone()),
from, &body.from,
&body.to,
&body.limit,
body.recurse,
body.dir, body.dir,
to,
limit,
)?; )?;
Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { Ok(get_relating_events_with_rel_type_and_event_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
next_batch: res.next_batch, next_batch: res.next_batch,
prev_batch: res.prev_batch, prev_batch: res.prev_batch,
recursion_depth: None, // TODO recursion_depth: res.recursion_depth,
}) })
} }
@ -59,26 +40,6 @@ pub async fn get_relating_events_with_rel_type_route(
) -> Result<get_relating_events_with_rel_type::v1::Response> { ) -> Result<get_relating_events_with_rel_type::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let res = services() let res = services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -88,17 +49,18 @@ pub async fn get_relating_events_with_rel_type_route(
&body.event_id, &body.event_id,
&None, &None,
&Some(body.rel_type.clone()), &Some(body.rel_type.clone()),
from, &body.from,
&body.to,
&body.limit,
body.recurse,
body.dir, body.dir,
to,
limit,
)?; )?;
Ok(get_relating_events_with_rel_type::v1::Response { Ok(get_relating_events_with_rel_type::v1::Response {
chunk: res.chunk, chunk: res.chunk,
next_batch: res.next_batch, next_batch: res.next_batch,
prev_batch: res.prev_batch, prev_batch: res.prev_batch,
recursion_depth: None, // TODO recursion_depth: res.recursion_depth,
}) })
} }
@ -108,26 +70,6 @@ pub async fn get_relating_events_route(
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let sender_user = body.sender_user.as_ref().expect("user is authenticated"); let sender_user = body.sender_user.as_ref().expect("user is authenticated");
let from = match body.from.clone() {
Some(from) => PduCount::try_from_string(&from)?,
None => match body.dir {
ruma::api::Direction::Forward => PduCount::min(),
ruma::api::Direction::Backward => PduCount::max(),
},
};
let to = body
.to
.as_ref()
.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = body
.limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
services() services()
.rooms .rooms
.pdu_metadata .pdu_metadata
@ -137,9 +79,10 @@ pub async fn get_relating_events_route(
&body.event_id, &body.event_id,
&None, &None,
&None, &None,
from, &body.from,
&body.to,
&body.limit,
body.recurse,
body.dir, body.dir,
to,
limit,
) )
} }

View file

@ -5,7 +5,7 @@ pub use data::Data;
use ruma::{ use ruma::{
api::{client::relations::get_relating_events, Direction}, api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType}, events::{relation::RelationType, TimelineEventType},
EventId, RoomId, UserId, EventId, RoomId, UInt, UserId,
}; };
use serde::Deserialize; use serde::Deserialize;
@ -42,18 +42,44 @@ impl Service {
#[allow(clippy::too_many_arguments)] #[allow(clippy::too_many_arguments)]
pub fn paginate_relations_with_filter( pub fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>, &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>,
filter_rel_type: &Option<RelationType>, from: PduCount, dir: Direction, to: Option<PduCount>, limit: usize, filter_rel_type: &Option<RelationType>, from: &Option<String>, to: &Option<String>, limit: &Option<UInt>,
recurse: bool, dir: Direction,
) -> Result<get_relating_events::v1::Response> { ) -> Result<get_relating_events::v1::Response> {
let from = match from {
Some(from) => PduCount::try_from_string(from)?,
None => match dir {
Direction::Forward => PduCount::min(),
Direction::Backward => PduCount::max(),
},
};
let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = limit
.and_then(|u| u32::try_from(u).ok())
.map_or(10_usize, |u| u as usize)
.min(100);
let next_token; let next_token;
// Spec (v1.10) recommends depth of at least 3
let depth: u8 = if recurse {
3
} else {
1
};
match dir { match dir {
Direction::Forward => { Direction::Forward => {
let events_after: Vec<_> = services() let relations_until =
&services()
.rooms .rooms
.pdu_metadata .pdu_metadata
.relations_until(sender_user, room_id, target, from)? // TODO: should be relations_after .relations_until(sender_user, room_id, target, from, depth)?;
.filter(|r| { let events_after: Vec<_> = relations_until // TODO: should be relations_after
r.as_ref().map_or(true, |(_, pdu)| { .iter()
.filter(|(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
&& if let Ok(content) = && if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
@ -65,9 +91,7 @@ impl Service {
false false
} }
}) })
})
.take(limit) .take(limit)
.filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
@ -75,7 +99,7 @@ impl Service {
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect(); .collect();
next_token = events_after.last().map(|(count, _)| count).copied(); next_token = events_after.last().map(|(count, _)| count).copied();
@ -90,16 +114,22 @@ impl Service {
chunk: events_after, chunk: events_after,
next_batch: next_token.map(|t| t.stringify()), next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()), prev_batch: Some(from.stringify()),
recursion_depth: None, // TODO recursion_depth: if recurse {
Some(depth.into())
} else {
None
},
}) })
}, },
Direction::Backward => { Direction::Backward => {
let events_before: Vec<_> = services() let relations_until =
&services()
.rooms .rooms
.pdu_metadata .pdu_metadata
.relations_until(sender_user, room_id, target, from)? .relations_until(sender_user, room_id, target, from, depth)?;
.filter(|r| { let events_before: Vec<_> = relations_until
r.as_ref().map_or(true, |(_, pdu)| { .iter()
.filter(|(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t) filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
&& if let Ok(content) = && if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
@ -111,9 +141,7 @@ impl Service {
false false
} }
}) })
})
.take(limit) .take(limit)
.filter_map(Result::ok) // Filter out buggy events
.filter(|(_, pdu)| { .filter(|(_, pdu)| {
services() services()
.rooms .rooms
@ -121,7 +149,7 @@ impl Service {
.user_can_see_event(sender_user, room_id, &pdu.event_id) .user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false) .unwrap_or(false)
}) })
.take_while(|&(k, _)| Some(k) != to) // Stop at `to` .take_while(|&(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect(); .collect();
next_token = events_before.last().map(|(count, _)| count).copied(); next_token = events_before.last().map(|(count, _)| count).copied();
@ -135,15 +163,19 @@ impl Service {
chunk: events_before, chunk: events_before,
next_batch: next_token.map(|t| t.stringify()), next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()), prev_batch: Some(from.stringify()),
recursion_depth: None, // TODO recursion_depth: if recurse {
Some(depth.into())
} else {
None
},
}) })
}, },
} }
} }
pub fn relations_until<'a>( pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, &'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<impl Iterator<Item = Result<(PduCount, PduEvent)>> + 'a> { ) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?; let room_id = services().rooms.short.get_or_create_shortroomid(room_id)?;
#[allow(unknown_lints)] #[allow(unknown_lints)]
#[allow(clippy::manual_unwrap_or_default)] #[allow(clippy::manual_unwrap_or_default)]
@ -152,7 +184,34 @@ impl Service {
// TODO: Support backfilled relations // TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator _ => 0, // This will result in an empty iterator
}; };
self.db.relations_until(user_id, room_id, target, until)
self.db
.relations_until(user_id, room_id, target, until)
.map(|mut relations| {
let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect();
let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect();
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) {
for relation in relations.flatten() {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1 + 1));
}
pdus.push(relation);
}
}
}
pdus.sort_by(|a, b| a.0.partial_cmp(&b.0).expect("u64s can always be compared"));
pdus
})
} }
#[tracing::instrument(skip(self, room_id, event_ids))] #[tracing::instrument(skip(self, room_id, event_ids))]