From badb83484fe913188f683b73e7d2afd86b0b29ba Mon Sep 17 00:00:00 2001 From: strawberry Date: Tue, 10 Dec 2024 22:54:19 -0500 Subject: [PATCH] fix private read receipt support Signed-off-by: strawberry --- src/api/client/read_marker.rs | 129 ++++++++++++------------- src/api/client/sync/v3.rs | 21 +++- src/service/rooms/read_receipt/data.rs | 6 +- src/service/rooms/read_receipt/mod.rs | 64 ++++++++++-- 4 files changed, 140 insertions(+), 80 deletions(-) diff --git a/src/api/client/read_marker.rs b/src/api/client/read_marker.rs index f28b2aec..f6123614 100644 --- a/src/api/client/read_marker.rs +++ b/src/api/client/read_marker.rs @@ -1,9 +1,9 @@ use std::collections::BTreeMap; use axum::extract::State; -use conduit::PduCount; +use conduit::{err, Err, PduCount}; use ruma::{ - api::client::{error::ErrorKind, read_marker::set_read_marker, receipt::create_receipt}, + api::client::{read_marker::set_read_marker, receipt::create_receipt}, events::{ receipt::{ReceiptThread, ReceiptType}, RoomAccountDataEventType, @@ -11,7 +11,7 @@ use ruma::{ MilliSecondsSinceUnixEpoch, }; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `POST /_matrix/client/r0/rooms/{roomId}/read_markers` /// @@ -23,14 +23,15 @@ use crate::{Error, Result, Ruma}; pub(crate) async fn set_read_marker_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user(); - if let Some(fully_read) = &body.fully_read { + if let Some(event) = &body.fully_read { let fully_read_event = ruma::events::fully_read::FullyReadEvent { content: ruma::events::fully_read::FullyReadEventContent { - event_id: fully_read.clone(), + event_id: event.clone(), }, }; + services .account_data .update( @@ -49,44 +50,20 @@ pub(crate) async fn set_read_marker_route( .reset_notification_counts(sender_user, &body.room_id); } - if let Some(event) = &body.private_read_receipt { - let count = services - .rooms - .timeline - .get_pdu_count(event) - .await - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; - - let count = match count { - PduCount::Backfilled(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Read receipt is in backfilled timeline", - )) - }, - PduCount::Normal(c) => c, - }; - services - .rooms - .read_receipt - .private_read_set(&body.room_id, sender_user, count); - } - if let Some(event) = &body.read_receipt { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - thread: ReceiptThread::Unthreaded, - }, - ); - - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); - - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(event.to_owned(), receipts); + let receipt_content = BTreeMap::from_iter([( + event.to_owned(), + BTreeMap::from_iter([( + ReceiptType::Read, + BTreeMap::from_iter([( + sender_user.to_owned(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + )]), + )]), + )]); services .rooms @@ -102,6 +79,24 @@ pub(crate) async fn set_read_marker_route( .await; } + if let Some(event) = &body.private_read_receipt { + let count = services + .rooms + .timeline + .get_pdu_count(event) + .await + .map_err(|_| err!(Request(NotFound("Event not found."))))?; + + let PduCount::Normal(count) = count else { + return Err!(Request(InvalidParam("Event is a backfilled PDU and cannot be marked as read."))); + }; + + services + .rooms + .read_receipt + .private_read_set(&body.room_id, sender_user, count); + } + Ok(set_read_marker::v3::Response {}) } @@ -111,7 +106,7 @@ pub(crate) async fn set_read_marker_route( pub(crate) async fn create_receipt_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); + let sender_user = body.sender_user(); if matches!( &body.receipt_type, @@ -141,19 +136,19 @@ pub(crate) async fn create_receipt_route( .await?; }, create_receipt::v3::ReceiptType::Read => { - let mut user_receipts = BTreeMap::new(); - user_receipts.insert( - sender_user.clone(), - ruma::events::receipt::Receipt { - ts: Some(MilliSecondsSinceUnixEpoch::now()), - thread: ReceiptThread::Unthreaded, - }, - ); - let mut receipts = BTreeMap::new(); - receipts.insert(ReceiptType::Read, user_receipts); - - let mut receipt_content = BTreeMap::new(); - receipt_content.insert(body.event_id.clone(), receipts); + let receipt_content = BTreeMap::from_iter([( + body.event_id.clone(), + BTreeMap::from_iter([( + ReceiptType::Read, + BTreeMap::from_iter([( + sender_user.to_owned(), + ruma::events::receipt::Receipt { + ts: Some(MilliSecondsSinceUnixEpoch::now()), + thread: ReceiptThread::Unthreaded, + }, + )]), + )]), + )]); services .rooms @@ -174,23 +169,23 @@ pub(crate) async fn create_receipt_route( .timeline .get_pdu_count(&body.event_id) .await - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Event does not exist."))?; + .map_err(|_| err!(Request(NotFound("Event not found."))))?; - let count = match count { - PduCount::Backfilled(_) => { - return Err(Error::BadRequest( - ErrorKind::InvalidParam, - "Read receipt is in backfilled timeline", - )) - }, - PduCount::Normal(c) => c, + let PduCount::Normal(count) = count else { + return Err!(Request(InvalidParam("Event is a backfilled PDU and cannot be marked as read."))); }; + services .rooms .read_receipt .private_read_set(&body.room_id, sender_user, count); }, - _ => return Err(Error::bad_database("Unsupported receipt type")), + _ => { + return Err!(Request(InvalidParam(warn!( + "Received unknown read receipt type: {}", + &body.receipt_type + )))) + }, } Ok(create_receipt::v3::Response {}) diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index 28ca1ea2..44572970 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -9,8 +9,8 @@ use conduit::{ at, err, error, extract_variant, is_equal_to, is_false, pdu::EventHash, result::{FlatOk, LogDebugErr}, - utils, utils::{ + self, future::OptionExt, math::ruma_from_u64, stream::{BroadbandExt, Tools, WidebandExt}, @@ -740,9 +740,28 @@ async fn load_joined_room( let (notification_count, highlight_count) = unread_notifications; device_list_updates.extend(device_updates); + + let last_privateread_update = services + .rooms + .read_receipt + .last_privateread_update(sender_user, room_id) + .await > since; + + let private_read_event = if last_privateread_update { + services + .rooms + .read_receipt + .private_read_get(room_id, sender_user) + .await + .ok() + } else { + None + }; + let edus: Vec> = receipt_events .into_values() .chain(typing_events.into_iter()) + .chain(private_read_event.into_iter()) .collect(); // Save the state after this sync so we can send the correct state diff next diff --git a/src/service/rooms/read_receipt/data.rs b/src/service/rooms/read_receipt/data.rs index 34639e27..9a1dba45 100644 --- a/src/service/rooms/read_receipt/data.rs +++ b/src/service/rooms/read_receipt/data.rs @@ -79,15 +79,15 @@ impl Data { .ignore_err() } - pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { + pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, pdu_count: u64) { let key = (room_id, user_id); let next_count = self.services.globals.next_count().unwrap(); - self.roomuserid_privateread.put(key, count); + self.roomuserid_privateread.put(key, pdu_count); self.roomuserid_lastprivatereadupdate.put(key, next_count); } - pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { + pub(super) async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { let key = (room_id, user_id); self.roomuserid_privateread.qry(&key).await.deserialized() } diff --git a/src/service/rooms/read_receipt/mod.rs b/src/service/rooms/read_receipt/mod.rs index e089d369..a3cd7098 100644 --- a/src/service/rooms/read_receipt/mod.rs +++ b/src/service/rooms/read_receipt/mod.rs @@ -2,19 +2,19 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{debug, Result}; -use futures::Stream; +use conduit::{debug, err, warn, PduCount, PduId, RawPduId, Result}; +use futures::{try_join, Stream, TryFutureExt}; use ruma::{ events::{ receipt::{ReceiptEvent, ReceiptEventContent}, AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent, }, serde::Raw, - OwnedUserId, RoomId, UserId, + OwnedEventId, OwnedUserId, RoomId, UserId, }; use self::data::{Data, ReceiptItem}; -use crate::{sending, Dep}; +use crate::{rooms, sending, Dep}; pub struct Service { services: Services, @@ -23,6 +23,8 @@ pub struct Service { struct Services { sending: Dep, + short: Dep, + timeline: Dep, } impl crate::Service for Service { @@ -30,6 +32,8 @@ impl crate::Service for Service { Ok(Arc::new(Self { services: Services { sending: args.depend::("sending"), + short: args.depend::("rooms::short"), + timeline: args.depend::("rooms::timeline"), }, db: Data::new(&args), })) @@ -49,6 +53,48 @@ impl Service { .expect("room flush failed"); } + /// Gets the latest private read receipt from the user in the room + pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result> { + let pdu_count = self + .private_read_get_count(room_id, user_id) + .map_err(|e| err!(Database(warn!("No private read receipt was set in {room_id}: {e}")))); + let shortroomid = self + .services + .short + .get_shortroomid(room_id) + .map_err(|e| err!(Database(warn!("Short room ID does not exist in database for {room_id}: {e}")))); + let (pdu_count, shortroomid) = try_join!(pdu_count, shortroomid)?; + + let shorteventid = PduCount::Normal(pdu_count); + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid, + } + .into(); + + let pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await?; + + let event_id: OwnedEventId = pdu.event_id.into(); + let receipt_content = BTreeMap::from_iter([( + event_id, + BTreeMap::from_iter([( + ruma::events::receipt::ReceiptType::ReadPrivate, + BTreeMap::from_iter([( + user_id, + ruma::events::receipt::Receipt { + ts: None, // TODO: start storing the timestamp so we can return one + thread: ruma::events::receipt::ReceiptThread::Unthreaded, + }, + )]), + )]), + )]); + //let receipt_json = Json + + let event = serde_json::value::to_raw_value(&receipt_content).expect("receipt_content created manually"); + + Ok(Raw::from_json(event)) + } + /// Returns an iterator over the most recent read_receipts in a room that /// happened after the event with id `since`. #[inline] @@ -59,21 +105,21 @@ impl Service { self.db.readreceipts_since(room_id, since) } - /// Sets a private read marker at `count`. + /// Sets a private read marker at PDU `count`. #[inline] #[tracing::instrument(skip(self), level = "debug")] pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) { self.db.private_read_set(room_id, user_id, count); } - /// Returns the private read marker. + /// Returns the private read marker PDU count. #[inline] #[tracing::instrument(skip(self), level = "debug")] - pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result { - self.db.private_read_get(room_id, user_id).await + pub async fn private_read_get_count(&self, room_id: &RoomId, user_id: &UserId) -> Result { + self.db.private_read_get_count(room_id, user_id).await } - /// Returns the count of the last typing update in this room. + /// Returns the PDU count of the last typing update in this room. #[inline] pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 { self.db.last_privateread_update(user_id, room_id).await