diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 5b492cb1..d07f6ac1 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -2,7 +2,7 @@ use std::iter::once; use axum::extract::State; use conduit::{ - err, error, + at, err, error, utils::{future::TryExtExt, stream::ReadyExt, IterStream}, Err, Result, }; @@ -82,7 +82,7 @@ pub(crate) async fn get_context_route( let events_before: Vec<_> = services .rooms .timeline - .pdus_until(sender_user, room_id, base_token) + .pdus_rev(sender_user, room_id, base_token.saturating_sub(1)) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -94,7 +94,7 @@ pub(crate) async fn get_context_route( let events_after: Vec<_> = services .rooms .timeline - .pdus_after(sender_user, room_id, base_token) + .pdus(sender_user, room_id, base_token.saturating_add(1)) .await? .ready_filter_map(|item| event_filter(item, filter)) .filter_map(|item| ignored_filter(&services, item, sender_user)) @@ -168,22 +168,28 @@ pub(crate) async fn get_context_route( start: events_before .last() - .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) - .into(), + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() + .map(ToString::to_string), end: events_after .last() - .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) - .into(), + .map(at!(0)) + .map(|count| count.saturating_add(1)) + .as_ref() + .map(ToString::to_string), events_before: events_before .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), events_after: events_after .into_iter() - .map(|(_, pdu)| pdu.to_room_event()) + .map(at!(1)) + .map(|pdu| pdu.to_room_event()) .collect(), state, diff --git a/src/api/client/message.rs b/src/api/client/message.rs index cb261a7f..e76325aa 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -100,14 +100,14 @@ pub(crate) async fn get_message_events_route( Direction::Forward => services .rooms .timeline - .pdus_after(sender_user, room_id, from) + .pdus(sender_user, room_id, from) .await? .boxed(), Direction::Backward => services .rooms .timeline - .pdus_until(sender_user, room_id, from) + .pdus_rev(sender_user, room_id, from) .await? .boxed(), }; @@ -136,7 +136,12 @@ pub(crate) async fn get_message_events_route( .collect() .await; - let next_token = events.last().map(|(count, _)| count).copied(); + let start_token = events.first().map(at!(0)).unwrap_or(from); + + let next_token = events + .last() + .map(at!(0)) + .map(|count| count.saturating_inc(body.dir)); if !cfg!(feature = "element_hacks") { if let Some(next_token) = next_token { @@ -154,8 +159,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.to_string(), - end: next_token.as_ref().map(PduCount::to_string), + start: start_token.to_string(), + end: next_token.as_ref().map(ToString::to_string), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index b5d1485b..ee62dbfc 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -150,10 +150,7 @@ async fn paginate_relations_with_filter( Direction::Backward => events.first(), } .map(at!(0)) - .map(|count| match dir { - Direction::Forward => count.saturating_add(1), - Direction::Backward => count.saturating_sub(1), - }) + .map(|count| count.saturating_inc(dir)) .as_ref() .map(ToString::to_string); diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index 7aec7186..f047d176 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -24,7 +24,7 @@ async fn load_timeline( let mut non_timeline_pdus = services .rooms .timeline - .pdus_until(sender_user, room_id, PduCount::max()) + .pdus_rev(sender_user, room_id, PduCount::max()) .await? .ready_take_while(|(pducount, _)| *pducount > roomsincecount); diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 00976c78..ea487d8e 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -6,7 +6,7 @@ use std::{ use axum::extract::State; use conduit::{ - err, error, extract_variant, is_equal_to, + at, err, error, extract_variant, is_equal_to, result::FlatOk, utils::{math::ruma_from_u64, BoolExt, IterStream, ReadyExt, TryFutureExtExt}, PduCount, @@ -945,15 +945,10 @@ async fn load_joined_room( let prev_batch = timeline_pdus .first() - .map_or(Ok::<_, Error>(None), |(pdu_count, _)| { - Ok(Some(match pdu_count { - PduCount::Backfilled(_) => { - error!("timeline in backfill state?!"); - "0".to_owned() - }, - PduCount::Normal(c) => c.to_string(), - })) - })?; + .map(at!(0)) + .map(|count| count.saturating_sub(1)) + .as_ref() + .map(ToString::to_string); let room_events: Vec<_> = timeline_pdus .iter() diff --git a/src/api/server/backfill.rs b/src/api/server/backfill.rs index 281bf2a2..47f02841 100644 --- a/src/api/server/backfill.rs +++ b/src/api/server/backfill.rs @@ -51,7 +51,7 @@ pub(crate) async fn get_backfill_route( let pdus = services .rooms .timeline - .pdus_until(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) + .pdus_rev(user_id!("@doesntmatter:conduit.rs"), &body.room_id, until) .await? .take(limit) .filter_map(|(_, pdu)| async move { diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs index 90e552e8..aceec1e8 100644 --- a/src/core/pdu/count.rs +++ b/src/core/pdu/count.rs @@ -2,6 +2,8 @@ use std::{cmp::Ordering, fmt, fmt::Display, str::FromStr}; +use ruma::api::Direction; + use crate::{err, Error, Result}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] @@ -54,6 +56,14 @@ impl PduCount { } } + #[inline] + pub fn checked_inc(self, dir: Direction) -> Result { + match dir { + Direction::Forward => self.checked_add(1), + Direction::Backward => self.checked_sub(1), + } + } + #[inline] pub fn checked_add(self, add: u64) -> Result { Ok(match self { @@ -82,6 +92,15 @@ impl PduCount { }) } + #[inline] + #[must_use] + pub fn saturating_inc(self, dir: Direction) -> Self { + match dir { + Direction::Forward => self.saturating_add(1), + Direction::Backward => self.saturating_sub(1), + } + } + #[inline] #[must_use] pub fn saturating_add(self, add: u64) -> Self { diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index f062e7e4..f320e6a0 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -62,7 +62,7 @@ impl Data { { hash_map::Entry::Occupied(o) => Ok(*o.get()), hash_map::Entry::Vacant(v) => Ok(self - .pdus_until(sender_user, room_id, PduCount::max()) + .pdus_rev(sender_user, room_id, PduCount::max()) .await? .next() .await @@ -201,10 +201,10 @@ impl Data { /// Returns an iterator over all events and their tokens in a room that /// happened before the event with id `until` in reverse-chronological /// order. - pub(super) async fn pdus_until<'a>( + pub(super) async fn pdus_rev<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, until, true).await?; + let current = self.count_to_id(room_id, until).await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -216,10 +216,10 @@ impl Data { Ok(stream) } - pub(super) async fn pdus_after<'a>( + pub(super) async fn pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - let current = self.count_to_id(room_id, from, false).await?; + let current = self.count_to_id(room_id, from).await?; let prefix = current.shortroomid(); let stream = self .pduid_pdu @@ -266,7 +266,7 @@ impl Data { } } - async fn count_to_id(&self, room_id: &RoomId, count: PduCount, subtract: bool) -> Result { + async fn count_to_id(&self, room_id: &RoomId, shorteventid: PduCount) -> Result { let shortroomid: ShortRoomId = self .services .short @@ -277,11 +277,7 @@ impl Data { // +1 so we don't send the base event let pdu_id = PduId { shortroomid, - shorteventid: if subtract { - count.checked_sub(1)? - } else { - count.checked_add(1)? - }, + shorteventid, }; Ok(pdu_id.into()) diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index 8255be7d..81d372d7 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -177,7 +177,7 @@ impl Service { #[tracing::instrument(skip(self), level = "debug")] pub async fn latest_pdu_in_room(&self, room_id: &RoomId) -> Result> { - self.pdus_until(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) + self.pdus_rev(user_id!("@placeholder:conduwuit.placeholder"), room_id, PduCount::max()) .await? .next() .await @@ -976,26 +976,23 @@ impl Service { pub async fn all_pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, ) -> Result + Send + 'a> { - self.pdus_after(user_id, room_id, PduCount::min()).await + self.pdus(user_id, room_id, PduCount::min()).await } - /// Returns an iterator over all events and their tokens in a room that - /// happened before the event with id `until` in reverse-chronological - /// order. + /// Reverse iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] - pub async fn pdus_until<'a>( + pub async fn pdus_rev<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - self.db.pdus_until(user_id, room_id, until).await + self.db.pdus_rev(user_id, room_id, until).await } - /// Returns an iterator over all events and their token in a room that - /// happened after the event with id `from` in chronological order. + /// Forward iteration starting at from. #[tracing::instrument(skip(self), level = "debug")] - pub async fn pdus_after<'a>( + pub async fn pdus<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - self.db.pdus_after(user_id, room_id, from).await + self.db.pdus(user_id, room_id, from).await } /// Replace a PDU with the redacted form.