diff --git a/Cargo.lock b/Cargo.lock index 44856753..f729d3d4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -76,6 +76,9 @@ name = "arrayvec" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c02d123df017efcdfbd739ef81735b36c5ba83ec3c59c80a9d7ecc718f92e50" +dependencies = [ + "serde", +] [[package]] name = "as_variant" diff --git a/Cargo.toml b/Cargo.toml index 043790f8..3ac1556c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -27,6 +27,7 @@ name = "conduit" [workspace.dependencies.arrayvec] version = "0.7.4" +features = ["std", "serde"] [workspace.dependencies.const-str] version = "0.5.7" diff --git a/src/api/client/context.rs b/src/api/client/context.rs index 9bf0c467..5b492cb1 100644 --- a/src/api/client/context.rs +++ b/src/api/client/context.rs @@ -168,12 +168,12 @@ pub(crate) async fn get_context_route( start: events_before .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) .into(), end: events_after .last() - .map_or_else(|| base_token.stringify(), |(count, _)| count.stringify()) + .map_or_else(|| base_token.to_string(), |(count, _)| count.to_string()) .into(), events_before: events_before diff --git a/src/api/client/membership.rs b/src/api/client/membership.rs index c41e93fa..fa71c0c8 100644 --- a/src/api/client/membership.rs +++ b/src/api/client/membership.rs @@ -1376,15 +1376,12 @@ pub(crate) async fn invite_helper( ) .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Origin field is invalid."))?; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value, true) .await? - .ok_or(Error::BadRequest( - ErrorKind::InvalidParam, - "Could not accept incoming PDU as timeline event.", - ))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept incoming PDU as timeline event."))))?; services.sending.send_pdu_room(room_id, &pdu_id).await?; return Ok(()); diff --git a/src/api/client/message.rs b/src/api/client/message.rs index 4fc58d9f..cb261a7f 100644 --- a/src/api/client/message.rs +++ b/src/api/client/message.rs @@ -62,19 +62,17 @@ pub(crate) async fn get_message_events_route( let room_id = &body.room_id; let filter = &body.filter; - let from_default = match body.dir { - Direction::Forward => PduCount::min(), - Direction::Backward => PduCount::max(), - }; - - let from = body + let from: PduCount = body .from .as_deref() - .map(PduCount::try_from_string) + .map(str::parse) .transpose()? 
- .unwrap_or(from_default); + .unwrap_or_else(|| match body.dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }); - let to = body.to.as_deref().map(PduCount::try_from_string).flat_ok(); + let to: Option = body.to.as_deref().map(str::parse).flat_ok(); let limit: usize = body .limit @@ -156,8 +154,8 @@ pub(crate) async fn get_message_events_route( .collect(); Ok(get_message_events::v3::Response { - start: from.stringify(), - end: next_token.as_ref().map(PduCount::stringify), + start: from.to_string(), + end: next_token.as_ref().map(PduCount::to_string), chunk, state, }) diff --git a/src/api/client/relations.rs b/src/api/client/relations.rs index 0456924c..ef7035e2 100644 --- a/src/api/client/relations.rs +++ b/src/api/client/relations.rs @@ -1,34 +1,43 @@ use axum::extract::State; -use ruma::api::client::relations::{ - get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, +use conduit::{ + at, + utils::{result::FlatOk, IterStream, ReadyExt}, + PduCount, Result, }; +use futures::{FutureExt, StreamExt}; +use ruma::{ + api::{ + client::relations::{ + get_relating_events, get_relating_events_with_rel_type, get_relating_events_with_rel_type_and_event_type, + }, + Direction, + }, + events::{relation::RelationType, TimelineEventType}, + EventId, RoomId, UInt, UserId, +}; +use service::{rooms::timeline::PdusIterItem, Services}; -use crate::{Result, Ruma}; +use crate::Ruma; /// # `GET /_matrix/client/r0/rooms/{roomId}/relations/{eventId}/{relType}/{eventType}` pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - - let res = services - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - body.event_type.clone().into(), - body.rel_type.clone().into(), - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) - .await?; - - Ok(get_relating_events_with_rel_type_and_event_type::v1::Response { + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + body.event_type.clone().into(), + body.rel_type.clone().into(), + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await + .map(|res| get_relating_events_with_rel_type_and_event_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, prev_batch: res.prev_batch, @@ -40,26 +49,21 @@ pub(crate) async fn get_relating_events_with_rel_type_and_event_type_route( pub(crate) async fn get_relating_events_with_rel_type_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); - - let res = services - .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - body.rel_type.clone().into(), - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) - .await?; - - Ok(get_relating_events_with_rel_type::v1::Response { + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + None, + body.rel_type.clone().into(), + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await + .map(|res| get_relating_events_with_rel_type::v1::Response { chunk: res.chunk, next_batch: res.next_batch, 
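The message and relations routes above now parse the optional `from`/`to` pagination tokens with `str::parse` (via the new `FromStr` on `PduCount`) and fall back to a direction-dependent default, instead of precomputing a `from_default`. A minimal, self-contained sketch of that pattern, using `i64` and a local `Direction` enum as stand-ins for conduit's `PduCount` and ruma's `Direction`:

```rust
// Stand-in types: i64 plays the role of PduCount, Direction mirrors ruma::api::Direction.
enum Direction {
    Forward,
    Backward,
}

fn parse_from(from: Option<&str>, dir: &Direction) -> Result<i64, std::num::ParseIntError> {
    let from: i64 = from
        .map(str::parse) // Option<&str> -> Option<Result<i64, ParseIntError>>
        .transpose()?    // -> Option<i64>, bubbling up any parse error
        .unwrap_or_else(|| match dir {
            // No token supplied: start at one end depending on pagination direction.
            Direction::Forward => i64::MIN,  // PduCount::min() in the real code
            Direction::Backward => i64::MAX, // PduCount::max() in the real code
        });
    Ok(from)
}

fn main() {
    assert_eq!(parse_from(Some("42"), &Direction::Forward).unwrap(), 42);
    assert_eq!(parse_from(None, &Direction::Backward).unwrap(), i64::MAX);
    assert!(parse_from(Some("nope"), &Direction::Forward).is_err());
}
```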
prev_batch: res.prev_batch, @@ -71,22 +75,103 @@ pub(crate) async fn get_relating_events_with_rel_type_route( pub(crate) async fn get_relating_events_route( State(services): State, body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_deref().expect("user is authenticated"); + paginate_relations_with_filter( + &services, + body.sender_user(), + &body.room_id, + &body.event_id, + None, + None, + body.from.as_deref(), + body.to.as_deref(), + body.limit, + body.recurse, + body.dir, + ) + .await +} + +#[allow(clippy::too_many_arguments)] +async fn paginate_relations_with_filter( + services: &Services, sender_user: &UserId, room_id: &RoomId, target: &EventId, + filter_event_type: Option, filter_rel_type: Option, from: Option<&str>, + to: Option<&str>, limit: Option, recurse: bool, dir: Direction, +) -> Result { + let from: PduCount = from + .map(str::parse) + .transpose()? + .unwrap_or_else(|| match dir { + Direction::Forward => PduCount::min(), + Direction::Backward => PduCount::max(), + }); + + let to: Option = to.map(str::parse).flat_ok(); + + // Use limit or else 30, with maximum 100 + let limit: usize = limit + .map(TryInto::try_into) + .flat_ok() + .unwrap_or(30) + .min(100); + + // Spec (v1.10) recommends depth of at least 3 + let depth: u8 = if recurse { + 3 + } else { + 1 + }; + + let events: Vec = services + .rooms + .pdu_metadata + .get_relations(sender_user, room_id, target, from, limit, depth, dir) + .await + .into_iter() + .filter(|(_, pdu)| { + filter_event_type + .as_ref() + .is_none_or(|kind| *kind == pdu.kind) + }) + .filter(|(_, pdu)| { + filter_rel_type + .as_ref() + .is_none_or(|rel_type| pdu.relation_type_equal(rel_type)) + }) + .stream() + .filter_map(|item| visibility_filter(services, sender_user, item)) + .ready_take_while(|(count, _)| Some(*count) != to) + .take(limit) + .collect() + .boxed() + .await; + + let next_batch = match dir { + Direction::Backward => events.first(), + Direction::Forward => events.last(), + } + .map(at!(0)) + .as_ref() + .map(ToString::to_string); + + Ok(get_relating_events::v1::Response { + next_batch, + prev_batch: Some(from.to_string()), + recursion_depth: recurse.then_some(depth.into()), + chunk: events + .into_iter() + .map(at!(1)) + .map(|pdu| pdu.to_message_like_event()) + .collect(), + }) +} + +async fn visibility_filter(services: &Services, sender_user: &UserId, item: PdusIterItem) -> Option { + let (_, pdu) = &item; services .rooms - .pdu_metadata - .paginate_relations_with_filter( - sender_user, - &body.room_id, - &body.event_id, - None, - None, - body.from.as_deref(), - body.to.as_deref(), - body.limit, - body.recurse, - body.dir, - ) + .state_accessor + .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) .await + .then_some(item) } diff --git a/src/api/client/sync/mod.rs b/src/api/client/sync/mod.rs index ed22010c..7aec7186 100644 --- a/src/api/client/sync/mod.rs +++ b/src/api/client/sync/mod.rs @@ -1,10 +1,7 @@ mod v3; mod v4; -use conduit::{ - utils::{math::usize_from_u64_truncated, ReadyExt}, - PduCount, -}; +use conduit::{utils::ReadyExt, PduCount}; use futures::StreamExt; use ruma::{RoomId, UserId}; @@ -12,7 +9,7 @@ pub(crate) use self::{v3::sync_events_route, v4::sync_events_v4_route}; use crate::{service::Services, Error, PduEvent, Result}; async fn load_timeline( - services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: u64, + services: &Services, sender_user: &UserId, room_id: &RoomId, roomsincecount: PduCount, limit: usize, ) -> Result<(Vec<(PduCount, PduEvent)>, 
bool), Error> { let last_timeline_count = services .rooms @@ -29,12 +26,12 @@ async fn load_timeline( .timeline .pdus_until(sender_user, room_id, PduCount::max()) .await? - .ready_take_while(|(pducount, _)| pducount > &roomsincecount); + .ready_take_while(|(pducount, _)| *pducount > roomsincecount); // Take the last events for the timeline let timeline_pdus: Vec<_> = non_timeline_pdus .by_ref() - .take(usize_from_u64_truncated(limit)) + .take(limit) .collect::>() .await .into_iter() diff --git a/src/api/client/sync/v3.rs b/src/api/client/sync/v3.rs index ccca1f85..08048902 100644 --- a/src/api/client/sync/v3.rs +++ b/src/api/client/sync/v3.rs @@ -432,28 +432,26 @@ async fn handle_left_room( left_state_ids.insert(leave_shortstatekey, left_event_id); - let mut i: u8 = 0; - for (key, id) in left_state_ids { - if full_state || since_state_ids.get(&key) != Some(&id) { - let (event_type, state_key) = services.rooms.short.get_statekey_from_short(key).await?; + for (shortstatekey, event_id) in left_state_ids { + if full_state || since_state_ids.get(&shortstatekey) != Some(&event_id) { + let (event_type, state_key) = services + .rooms + .short + .get_statekey_from_short(shortstatekey) + .await?; + // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565 if !lazy_load_enabled - || event_type != StateEventType::RoomMember - || full_state - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) + || event_type != StateEventType::RoomMember + || full_state + || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {}", id); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; left_state_events.push(pdu.to_sync_state_event()); - - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } } } } @@ -542,7 +540,7 @@ async fn load_joined_room( let insert_lock = services.rooms.timeline.mutex_insert.lock(room_id).await; drop(insert_lock); - let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10).await?; + let (timeline_pdus, limited) = load_timeline(services, sender_user, room_id, sincecount, 10_usize).await?; let send_notification_counts = !timeline_pdus.is_empty() || services @@ -678,8 +676,7 @@ async fn load_joined_room( let mut state_events = Vec::new(); let mut lazy_loaded = HashSet::new(); - let mut i: u8 = 0; - for (shortstatekey, id) in current_state_ids { + for (shortstatekey, event_id) in current_state_ids { let (event_type, state_key) = services .rooms .short @@ -687,24 +684,22 @@ async fn load_joined_room( .await?; if event_type != StateEventType::RoomMember { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; - state_events.push(pdu); - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } - } else if !lazy_load_enabled - || full_state - || timeline_users.contains(&state_key) - // TODO: Delete the following line when this is resolved: https://github.com/vector-im/element-web/issues/22565 - || (cfg!(feature = "element_hacks") && *sender_user == state_key) + 
state_events.push(pdu); + continue; + } + + // TODO: Delete "element_hacks" when this is resolved: https://github.com/vector-im/element-web/issues/22565 + if !lazy_load_enabled + || full_state || timeline_users.contains(&state_key) + || (cfg!(feature = "element_hacks") && *sender_user == state_key) { - let Ok(pdu) = services.rooms.timeline.get_pdu(&id).await else { - error!("Pdu in state not found: {id}"); + let Ok(pdu) = services.rooms.timeline.get_pdu(&event_id).await else { + error!("Pdu in state not found: {event_id}"); continue; }; @@ -712,12 +707,8 @@ async fn load_joined_room( if let Ok(uid) = UserId::parse(&state_key) { lazy_loaded.insert(uid); } - state_events.push(pdu); - i = i.wrapping_add(1); - if i % 100 == 0 { - tokio::task::yield_now().await; - } + state_events.push(pdu); } } diff --git a/src/api/client/sync/v4.rs b/src/api/client/sync/v4.rs index f8ada81c..11e3830c 100644 --- a/src/api/client/sync/v4.rs +++ b/src/api/client/sync/v4.rs @@ -8,7 +8,7 @@ use axum::extract::State; use conduit::{ debug, error, extract_variant, utils::{ - math::{ruma_from_usize, usize_from_ruma}, + math::{ruma_from_usize, usize_from_ruma, usize_from_u64_truncated}, BoolExt, IterStream, ReadyExt, TryFutureExtExt, }, warn, Error, PduCount, Result, @@ -350,14 +350,16 @@ pub(crate) async fn sync_events_v4_route( new_known_rooms.extend(room_ids.iter().cloned()); for room_id in &room_ids { - let todo_room = todo_rooms - .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); + let todo_room = + todo_rooms + .entry(room_id.clone()) + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); - let limit = list + let limit: usize = list .room_details .timeline_limit - .map_or(10, u64::from) + .map(u64::from) + .map_or(10, usize_from_u64_truncated) .min(100); todo_room @@ -406,8 +408,14 @@ pub(crate) async fn sync_events_v4_route( } let todo_room = todo_rooms .entry(room_id.clone()) - .or_insert((BTreeSet::new(), 0, u64::MAX)); - let limit = room.timeline_limit.map_or(10, u64::from).min(100); + .or_insert((BTreeSet::new(), 0_usize, u64::MAX)); + + let limit: usize = room + .timeline_limit + .map(u64::from) + .map_or(10, usize_from_u64_truncated) + .min(100); + todo_room.0.extend(room.required_state.iter().cloned()); todo_room.1 = todo_room.1.max(limit); // 0 means unknown because it got out of date diff --git a/src/api/client/threads.rs b/src/api/client/threads.rs index 50f6cdfb..02cf7992 100644 --- a/src/api/client/threads.rs +++ b/src/api/client/threads.rs @@ -1,19 +1,14 @@ use axum::extract::State; -use conduit::PduEvent; +use conduit::{PduCount, PduEvent}; use futures::StreamExt; -use ruma::{ - api::client::{error::ErrorKind, threads::get_threads}, - uint, -}; +use ruma::{api::client::threads::get_threads, uint}; -use crate::{Error, Result, Ruma}; +use crate::{Result, Ruma}; /// # `GET /_matrix/client/r0/rooms/{roomId}/threads` pub(crate) async fn get_threads_route( - State(services): State, body: Ruma, + State(services): State, ref body: Ruma, ) -> Result { - let sender_user = body.sender_user.as_ref().expect("user is authenticated"); - // Use limit or else 10, with maximum 100 let limit = body .limit @@ -22,38 +17,39 @@ pub(crate) async fn get_threads_route( .unwrap_or(10) .min(100); - let from = if let Some(from) = &body.from { - from.parse() - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, ""))? - } else { - u64::MAX - }; + let from: PduCount = body + .from + .as_deref() + .map(str::parse) + .transpose()? 
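Both the sliding-sync lists and the relations/threads endpoints clamp the client-supplied limit the same way: convert it, fall back to a default, cap it at 100. A hedged sketch of that pattern, with plain `u64`/`usize` conversions standing in for ruma's `UInt` and conduit's `usize_from_u64_truncated`/`FlatOk` helpers:

```rust
// clamp_limit is a hypothetical helper name; the routes above inline this chain.
fn clamp_limit(requested: Option<u64>, default: usize, max: usize) -> usize {
    requested
        .map(|n| usize::try_from(n)) // Option<Result<usize, TryFromIntError>>
        .and_then(Result::ok)        // keep only successful conversions
        .unwrap_or(default)          // e.g. 10 or 30 in the routes above
        .min(max)                    // server-side cap, 100 in the routes above
}

fn main() {
    assert_eq!(clamp_limit(None, 10, 100), 10);
    assert_eq!(clamp_limit(Some(25), 10, 100), 25);
    assert_eq!(clamp_limit(Some(10_000), 10, 100), 100);
}
```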
+ .unwrap_or_else(PduCount::max); - let room_id = &body.room_id; - let threads: Vec<(u64, PduEvent)> = services + let threads: Vec<(PduCount, PduEvent)> = services .rooms .threads - .threads_until(sender_user, &body.room_id, from, &body.include) + .threads_until(body.sender_user(), &body.room_id, from, &body.include) .await? .take(limit) .filter_map(|(count, pdu)| async move { services .rooms .state_accessor - .user_can_see_event(sender_user, room_id, &pdu.event_id) + .user_can_see_event(body.sender_user(), &body.room_id, &pdu.event_id) .await .then_some((count, pdu)) }) .collect() .await; - let next_batch = threads.last().map(|(count, _)| count.to_string()); - Ok(get_threads::v1::Response { + next_batch: threads + .last() + .map(|(count, _)| count) + .map(ToString::to_string), + chunk: threads .into_iter() .map(|(_, pdu)| pdu.to_room_event()) .collect(), - next_batch, }) } diff --git a/src/api/server/send_join.rs b/src/api/server/send_join.rs index c3273baf..f2ede9d0 100644 --- a/src/api/server/send_join.rs +++ b/src/api/server/send_join.rs @@ -156,12 +156,12 @@ async fn create_join_event( .lock(room_id) .await; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(&origin, room_id, &event_id, value.clone(), true) .await? - .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); diff --git a/src/api/server/send_leave.rs b/src/api/server/send_leave.rs index 7b4a8aee..448e5de3 100644 --- a/src/api/server/send_leave.rs +++ b/src/api/server/send_leave.rs @@ -1,7 +1,7 @@ #![allow(deprecated)] use axum::extract::State; -use conduit::{utils::ReadyExt, Error, Result}; +use conduit::{err, utils::ReadyExt, Error, Result}; use ruma::{ api::{client::error::ErrorKind, federation::membership::create_leave_event}, events::{ @@ -142,12 +142,12 @@ async fn create_leave_event( .lock(room_id) .await; - let pdu_id: Vec = services + let pdu_id = services .rooms .event_handler .handle_incoming_pdu(origin, room_id, &event_id, value, true) .await? 
- .ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Could not accept as timeline event."))?; + .ok_or_else(|| err!(Request(InvalidParam("Could not accept as timeline event."))))?; drop(mutex_lock); diff --git a/src/core/mod.rs b/src/core/mod.rs index 1b7b8fa1..4ab84730 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -17,7 +17,7 @@ pub use ::tracing; pub use config::Config; pub use error::Error; pub use info::{rustc_flags_capture, version, version::version}; -pub use pdu::{Event, PduBuilder, PduCount, PduEvent}; +pub use pdu::{Event, PduBuilder, PduCount, PduEvent, PduId, RawPduId}; pub use server::Server; pub use utils::{ctor, dtor, implement, result, result::Result}; diff --git a/src/core/pdu/count.rs b/src/core/pdu/count.rs index 094988b6..90e552e8 100644 --- a/src/core/pdu/count.rs +++ b/src/core/pdu/count.rs @@ -1,38 +1,135 @@ -use std::cmp::Ordering; +#![allow(clippy::cast_possible_wrap, clippy::cast_sign_loss, clippy::as_conversions)] -use ruma::api::client::error::ErrorKind; +use std::{cmp::Ordering, fmt, fmt::Display, str::FromStr}; -use crate::{Error, Result}; +use crate::{err, Error, Result}; #[derive(Hash, PartialEq, Eq, Clone, Copy, Debug)] pub enum PduCount { - Backfilled(u64), Normal(u64), + Backfilled(i64), } impl PduCount { + #[inline] #[must_use] - pub fn min() -> Self { Self::Backfilled(u64::MAX) } + pub fn from_unsigned(unsigned: u64) -> Self { Self::from_signed(unsigned as i64) } + #[inline] #[must_use] - pub fn max() -> Self { Self::Normal(u64::MAX) } - - pub fn try_from_string(token: &str) -> Result { - if let Some(stripped_token) = token.strip_prefix('-') { - stripped_token.parse().map(PduCount::Backfilled) - } else { - token.parse().map(PduCount::Normal) + pub fn from_signed(signed: i64) -> Self { + match signed { + i64::MIN..=0 => Self::Backfilled(signed), + _ => Self::Normal(signed as u64), } - .map_err(|_| Error::BadRequest(ErrorKind::InvalidParam, "Invalid pagination token.")) } + #[inline] #[must_use] - pub fn stringify(&self) -> String { + pub fn into_unsigned(self) -> u64 { + self.debug_assert_valid(); match self { - Self::Backfilled(x) => format!("-{x}"), - Self::Normal(x) => x.to_string(), + Self::Normal(i) => i, + Self::Backfilled(i) => i as u64, } } + + #[inline] + #[must_use] + pub fn into_signed(self) -> i64 { + self.debug_assert_valid(); + match self { + Self::Normal(i) => i as i64, + Self::Backfilled(i) => i, + } + } + + #[inline] + #[must_use] + pub fn into_normal(self) -> Self { + self.debug_assert_valid(); + match self { + Self::Normal(i) => Self::Normal(i), + Self::Backfilled(_) => Self::Normal(0), + } + } + + #[inline] + pub fn checked_add(self, add: u64) -> Result { + Ok(match self { + Self::Normal(i) => Self::Normal( + i.checked_add(add) + .ok_or_else(|| err!(Arithmetic("PduCount::Normal overflow")))?, + ), + Self::Backfilled(i) => Self::Backfilled( + i.checked_add(add as i64) + .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled overflow")))?, + ), + }) + } + + #[inline] + pub fn checked_sub(self, sub: u64) -> Result { + Ok(match self { + Self::Normal(i) => Self::Normal( + i.checked_sub(sub) + .ok_or_else(|| err!(Arithmetic("PduCount::Normal underflow")))?, + ), + Self::Backfilled(i) => Self::Backfilled( + i.checked_sub(sub as i64) + .ok_or_else(|| err!(Arithmetic("PduCount::Backfilled underflow")))?, + ), + }) + } + + #[inline] + #[must_use] + pub fn saturating_add(self, add: u64) -> Self { + match self { + Self::Normal(i) => Self::Normal(i.saturating_add(add)), + Self::Backfilled(i) => Self::Backfilled(i.saturating_add(add as 
i64)), + } + } + + #[inline] + #[must_use] + pub fn saturating_sub(self, sub: u64) -> Self { + match self { + Self::Normal(i) => Self::Normal(i.saturating_sub(sub)), + Self::Backfilled(i) => Self::Backfilled(i.saturating_sub(sub as i64)), + } + } + + #[inline] + #[must_use] + pub fn min() -> Self { Self::Backfilled(i64::MIN) } + + #[inline] + #[must_use] + pub fn max() -> Self { Self::Normal(i64::MAX as u64) } + + #[inline] + pub(crate) fn debug_assert_valid(&self) { + if let Self::Backfilled(i) = self { + debug_assert!(*i <= 0, "Backfilled sequence must be negative"); + } + } +} + +impl Display for PduCount { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> Result<(), fmt::Error> { + self.debug_assert_valid(); + match self { + Self::Normal(i) => write!(f, "{i}"), + Self::Backfilled(i) => write!(f, "{i}"), + } + } +} + +impl FromStr for PduCount { + type Err = Error; + + fn from_str(token: &str) -> Result { Ok(Self::from_signed(token.parse()?)) } } impl PartialOrd for PduCount { @@ -40,12 +137,9 @@ impl PartialOrd for PduCount { } impl Ord for PduCount { - fn cmp(&self, other: &Self) -> Ordering { - match (self, other) { - (Self::Normal(s), Self::Normal(o)) => s.cmp(o), - (Self::Backfilled(s), Self::Backfilled(o)) => o.cmp(s), - (Self::Normal(_), Self::Backfilled(_)) => Ordering::Greater, - (Self::Backfilled(_), Self::Normal(_)) => Ordering::Less, - } - } + fn cmp(&self, other: &Self) -> Ordering { self.into_signed().cmp(&other.into_signed()) } +} + +impl Default for PduCount { + fn default() -> Self { Self::Normal(0) } } diff --git a/src/core/pdu/id.rs b/src/core/pdu/id.rs new file mode 100644 index 00000000..05d11904 --- /dev/null +++ b/src/core/pdu/id.rs @@ -0,0 +1,22 @@ +use super::{PduCount, RawPduId}; +use crate::utils::u64_from_u8x8; + +pub type ShortRoomId = ShortId; +pub type ShortEventId = ShortId; +pub type ShortId = u64; + +#[derive(Clone, Copy, Debug, Eq, PartialEq)] +pub struct PduId { + pub shortroomid: ShortRoomId, + pub shorteventid: PduCount, +} + +impl From for PduId { + #[inline] + fn from(raw: RawPduId) -> Self { + Self { + shortroomid: u64_from_u8x8(raw.shortroomid()), + shorteventid: PduCount::from_unsigned(u64_from_u8x8(raw.shorteventid())), + } + } +} diff --git a/src/core/pdu/mod.rs b/src/core/pdu/mod.rs index 53fcd0a9..c785c99e 100644 --- a/src/core/pdu/mod.rs +++ b/src/core/pdu/mod.rs @@ -4,8 +4,12 @@ mod count; mod event; mod event_id; mod filter; +mod id; +mod raw_id; mod redact; +mod relation; mod strip; +mod tests; mod unsigned; use std::{cmp::Ordering, sync::Arc}; @@ -21,6 +25,8 @@ pub use self::{ count::PduCount, event::Event, event_id::*, + id::*, + raw_id::*, }; use crate::Result; diff --git a/src/core/pdu/raw_id.rs b/src/core/pdu/raw_id.rs new file mode 100644 index 00000000..faba1cbf --- /dev/null +++ b/src/core/pdu/raw_id.rs @@ -0,0 +1,117 @@ +use arrayvec::ArrayVec; + +use super::{PduCount, PduId, ShortEventId, ShortId, ShortRoomId}; + +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +pub enum RawPduId { + Normal(RawPduIdNormal), + Backfilled(RawPduIdBackfilled), +} + +type RawPduIdNormal = [u8; RawPduId::NORMAL_LEN]; +type RawPduIdBackfilled = [u8; RawPduId::BACKFILLED_LEN]; + +const INT_LEN: usize = size_of::(); + +impl RawPduId { + const BACKFILLED_LEN: usize = size_of::() + INT_LEN + size_of::(); + const MAX_LEN: usize = Self::BACKFILLED_LEN; + const NORMAL_LEN: usize = size_of::() + size_of::(); + + #[inline] + #[must_use] + pub fn pdu_count(&self) -> PduCount { + let id: PduId = (*self).into(); + id.shorteventid + } + + #[inline] + #[must_use] 
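The reworked `PduCount` above keys everything off a single signed representation: non-positive values are backfilled, positive values are normal, and `Display`, `FromStr`, and `Ord` all go through `into_signed`. A reduced stand-in sketch of that encoding (the `Count` name and helpers here are illustrative, not the real type):

```rust
use std::{cmp::Ordering, fmt, str::FromStr};

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Count {
    Normal(u64),
    Backfilled(i64),
}

impl Count {
    // Mirrors from_signed in count.rs: sign decides the variant.
    fn from_signed(signed: i64) -> Self {
        if signed <= 0 {
            Self::Backfilled(signed)
        } else {
            Self::Normal(signed as u64) // the real file allows these casts explicitly
        }
    }

    fn into_signed(self) -> i64 {
        match self {
            Self::Normal(i) => i as i64,
            Self::Backfilled(i) => i,
        }
    }
}

impl fmt::Display for Count {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Backfilled counts print with a leading '-', so tokens round-trip via FromStr.
        write!(f, "{}", self.into_signed())
    }
}

impl FromStr for Count {
    type Err = std::num::ParseIntError;

    fn from_str(token: &str) -> Result<Self, Self::Err> {
        Ok(Self::from_signed(token.parse()?))
    }
}

impl PartialOrd for Count {
    fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}

impl Ord for Count {
    // A single signed comparison replaces the old per-variant special cases.
    fn cmp(&self, other: &Self) -> Ordering { self.into_signed().cmp(&other.into_signed()) }
}

fn main() {
    let backfilled: Count = "-987654".parse().unwrap();
    let normal: Count = "42".parse().unwrap();
    assert!(matches!(backfilled, Count::Backfilled(_)));
    assert!(backfilled < normal); // every backfilled count sorts before normal ones
    assert_eq!(normal.to_string(), "42");
    assert_eq!(backfilled.to_string(), "-987654");
}
```

Routing comparison through the signed form is what lets the diff drop the old `Ord` impl that reversed the comparison for the backfilled variant.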
+ pub fn shortroomid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[0..INT_LEN] + .try_into() + .expect("normal raw shortroomid array from slice"), + Self::Backfilled(raw) => raw[0..INT_LEN] + .try_into() + .expect("backfilled raw shortroomid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn shorteventid(self) -> [u8; INT_LEN] { + match self { + Self::Normal(raw) => raw[INT_LEN..INT_LEN * 2] + .try_into() + .expect("normal raw shorteventid array from slice"), + Self::Backfilled(raw) => raw[INT_LEN * 2..INT_LEN * 3] + .try_into() + .expect("backfilled raw shorteventid array from slice"), + } + } + + #[inline] + #[must_use] + pub fn as_bytes(&self) -> &[u8] { + match self { + Self::Normal(ref raw) => raw, + Self::Backfilled(ref raw) => raw, + } + } +} + +impl AsRef<[u8]> for RawPduId { + #[inline] + fn as_ref(&self) -> &[u8] { self.as_bytes() } +} + +impl From<&[u8]> for RawPduId { + #[inline] + fn from(id: &[u8]) -> Self { + match id.len() { + Self::NORMAL_LEN => Self::Normal( + id[0..Self::NORMAL_LEN] + .try_into() + .expect("normal RawPduId from [u8]"), + ), + Self::BACKFILLED_LEN => Self::Backfilled( + id[0..Self::BACKFILLED_LEN] + .try_into() + .expect("backfilled RawPduId from [u8]"), + ), + _ => unimplemented!("unrecognized RawPduId length"), + } + } +} + +impl From for RawPduId { + #[inline] + fn from(id: PduId) -> Self { + const MAX_LEN: usize = RawPduId::MAX_LEN; + type RawVec = ArrayVec; + + let mut vec = RawVec::new(); + vec.extend(id.shortroomid.to_be_bytes()); + id.shorteventid.debug_assert_valid(); + match id.shorteventid { + PduCount::Normal(shorteventid) => { + vec.extend(shorteventid.to_be_bytes()); + Self::Normal( + vec.as_ref() + .try_into() + .expect("RawVec into RawPduId::Normal"), + ) + }, + PduCount::Backfilled(shorteventid) => { + vec.extend(0_u64.to_be_bytes()); + vec.extend(shorteventid.to_be_bytes()); + Self::Backfilled( + vec.as_ref() + .try_into() + .expect("RawVec into RawPduId::Backfilled"), + ) + }, + } + } +} diff --git a/src/core/pdu/relation.rs b/src/core/pdu/relation.rs new file mode 100644 index 00000000..ae156a3d --- /dev/null +++ b/src/core/pdu/relation.rs @@ -0,0 +1,22 @@ +use ruma::events::relation::RelationType; +use serde::Deserialize; + +use crate::implement; + +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelType { + rel_type: RelationType, +} +#[derive(Clone, Debug, Deserialize)] +struct ExtractRelatesToEventId { + #[serde(rename = "m.relates_to")] + relates_to: ExtractRelType, +} + +#[implement(super::PduEvent)] +#[must_use] +pub fn relation_type_equal(&self, rel_type: &RelationType) -> bool { + self.get_content() + .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) + .is_ok_and(|r| r == *rel_type) +} diff --git a/src/core/pdu/tests.rs b/src/core/pdu/tests.rs new file mode 100644 index 00000000..30ec23ba --- /dev/null +++ b/src/core/pdu/tests.rs @@ -0,0 +1,19 @@ +#![cfg(test)] + +use super::PduCount; + +#[test] +fn backfilled_parse() { + let count: PduCount = "-987654".parse().expect("parse() failed"); + let backfilled = matches!(count, PduCount::Backfilled(_)); + + assert!(backfilled, "not backfilled variant"); +} + +#[test] +fn normal_parse() { + let count: PduCount = "987654".parse().expect("parse() failed"); + let backfilled = matches!(count, PduCount::Backfilled(_)); + + assert!(!backfilled, "backfilled variant"); +} diff --git a/src/service/migrations.rs b/src/service/migrations.rs index 45323fa2..d6c342f8 100644 --- a/src/service/migrations.rs +++ b/src/service/migrations.rs @@ -71,7 +71,7 
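`RawPduId` above is a fixed-width database key: 16 bytes for normal events (shortroomid then shorteventid, both big-endian) and 24 bytes for backfilled events (an extra all-zero `u64` marker before the signed shorteventid), with the variant recovered from the key length. A sketch of that layout using a plain `Vec<u8>` in place of the `ArrayVec`-backed arrays, and a local `Count` enum in place of `PduCount`:

```rust
enum Count {
    Normal(u64),
    Backfilled(i64),
}

// Hypothetical helper showing the byte layout RawPduId's From<PduId> produces:
//   normal:     [ shortroomid (8 BE bytes) | shorteventid (8 BE bytes) ]
//   backfilled: [ shortroomid (8 BE bytes) | 0u64 marker | signed shorteventid (8 BE bytes) ]
fn raw_pdu_key(shortroomid: u64, shorteventid: Count) -> Vec<u8> {
    let mut key = Vec::with_capacity(24);
    key.extend_from_slice(&shortroomid.to_be_bytes());
    match shorteventid {
        Count::Normal(n) => key.extend_from_slice(&n.to_be_bytes()),
        Count::Backfilled(n) => {
            // Zero marker, presumably so backfilled entries group before the
            // normal range under the database's byte-wise key ordering.
            key.extend_from_slice(&0_u64.to_be_bytes());
            key.extend_from_slice(&n.to_be_bytes());
        },
    }
    key
}

fn main() {
    let normal = raw_pdu_key(7, Count::Normal(42));
    let backfilled = raw_pdu_key(7, Count::Backfilled(-3));
    assert_eq!(normal.len(), 16);
    assert_eq!(backfilled.len(), 24);
    // The first 8 bytes of either form are the big-endian shortroomid.
    assert_eq!(&normal[0..8], &7_u64.to_be_bytes());
    assert_eq!(&backfilled[0..8], &7_u64.to_be_bytes());
}
```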
@@ async fn fresh(services: &Services) -> Result<()> { db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []); // Create the admin room and server user on first run - crate::admin::create_admin_room(services).await?; + crate::admin::create_admin_room(services).boxed().await?; warn!( "Created new {} database with version {DATABASE_VERSION}", diff --git a/src/service/rooms/auth_chain/data.rs b/src/service/rooms/auth_chain/data.rs index 5c9dbda8..3c36928a 100644 --- a/src/service/rooms/auth_chain/data.rs +++ b/src/service/rooms/auth_chain/data.rs @@ -7,9 +7,11 @@ use conduit::{err, utils, utils::math::usize_from_f64, Err, Result}; use database::Map; use lru_cache::LruCache; +use crate::rooms::short::ShortEventId; + pub(super) struct Data { shorteventid_authchain: Arc, - pub(super) auth_chain_cache: Mutex, Arc<[u64]>>>, + pub(super) auth_chain_cache: Mutex, Arc<[ShortEventId]>>>, } impl Data { @@ -24,7 +26,7 @@ impl Data { } } - pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); // Check RAM cache @@ -63,7 +65,7 @@ impl Data { Ok(chain) } - pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[u64]>) { + pub(super) fn cache_auth_chain(&self, key: Vec, auth_chain: Arc<[ShortEventId]>) { debug_assert!(!key.is_empty(), "auth_chain key must not be empty"); // Only persist single events in db diff --git a/src/service/rooms/auth_chain/mod.rs b/src/service/rooms/auth_chain/mod.rs index 1387bc7d..c22732c2 100644 --- a/src/service/rooms/auth_chain/mod.rs +++ b/src/service/rooms/auth_chain/mod.rs @@ -10,7 +10,7 @@ use futures::Stream; use ruma::{EventId, RoomId}; use self::data::Data; -use crate::{rooms, Dep}; +use crate::{rooms, rooms::short::ShortEventId, Dep}; pub struct Service { services: Services, @@ -64,7 +64,7 @@ impl Service { } #[tracing::instrument(skip_all, name = "auth_chain")] - pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { + pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result> { const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db? 
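The auth-chain cache now states what its `u64`s mean by using the `ShortEventId`/`ShortRoomId` aliases exported from `core/pdu/id.rs`. A small sketch of those aliases and a cache shaped like `auth_chain_cache` (a `HashMap` stands in for the service's `LruCache`):

```rust
use std::{collections::HashMap, sync::Arc};

// The aliases as defined in core/pdu/id.rs above; they are all u64 underneath.
pub type ShortId = u64;
pub type ShortRoomId = ShortId;
pub type ShortEventId = ShortId;

// Hypothetical cache with the same shape as auth_chain_cache: a chunk of
// shorteventids keys the combined auth chain computed for that chunk.
type AuthChainCache = HashMap<Vec<ShortEventId>, Arc<[ShortEventId]>>;

fn main() {
    let _room: ShortRoomId = 7; // rooms get the same width of short id

    let mut cache: AuthChainCache = HashMap::new();
    cache.insert(vec![1, 2, 3], Arc::from([4_u64, 5, 6].as_slice()));
    assert!(cache.contains_key([1_u64, 2, 3].as_slice()));
}
```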
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new(); @@ -97,7 +97,7 @@ impl Service { continue; } - let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); + let chunk_key: Vec = chunk.iter().map(|(short, _)| short).copied().collect(); if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await { trace!("Found cache entry for whole chunk"); full_auth_chain.extend(cached.iter().copied()); @@ -156,7 +156,7 @@ impl Service { } #[tracing::instrument(skip(self, room_id))] - async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { + async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result> { let mut todo = vec![Arc::from(event_id)]; let mut found = HashSet::new(); @@ -195,19 +195,19 @@ impl Service { } #[inline] - pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { + pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result> { self.db.get_cached_eventid_authchain(key).await } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { - let val = auth_chain.iter().copied().collect::>(); + pub fn cache_auth_chain(&self, key: Vec, auth_chain: &HashSet) { + let val = auth_chain.iter().copied().collect::>(); self.db.cache_auth_chain(key, val); } #[tracing::instrument(skip(self), level = "debug")] - pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { - let val = auth_chain.iter().copied().collect::>(); + pub fn cache_auth_chain_vec(&self, key: Vec, auth_chain: &Vec) { + let val = auth_chain.iter().copied().collect::>(); self.db.cache_auth_chain(key, val); } diff --git a/src/service/rooms/event_handler/mod.rs b/src/service/rooms/event_handler/mod.rs index adebd332..f76f817d 100644 --- a/src/service/rooms/event_handler/mod.rs +++ b/src/service/rooms/event_handler/mod.rs @@ -35,7 +35,10 @@ use ruma::{ use crate::{ globals, rooms, - rooms::state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + rooms::{ + state_compressor::{CompressedStateEvent, HashSetCompressStateEvent}, + timeline::RawPduId, + }, sending, server_keys, Dep, }; @@ -136,10 +139,10 @@ impl Service { pub async fn handle_incoming_pdu<'a>( &self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId, value: BTreeMap, is_timeline_event: bool, - ) -> Result>> { + ) -> Result> { // 1. 
Skip the PDU if we already have it as a timeline event if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await { - return Ok(Some(pdu_id.to_vec())); + return Ok(Some(pdu_id)); } // 1.1 Check the server is in the room @@ -488,7 +491,7 @@ impl Service { pub async fn upgrade_outlier_to_timeline_pdu( &self, incoming_pdu: Arc, val: BTreeMap, create_event: &PduEvent, origin: &ServerName, room_id: &RoomId, - ) -> Result>> { + ) -> Result> { // Skip the PDU if we already have it as a timeline event if let Ok(pduid) = self .services @@ -496,7 +499,7 @@ impl Service { .get_pdu_id(&incoming_pdu.event_id) .await { - return Ok(Some(pduid.to_vec())); + return Ok(Some(pduid)); } if self diff --git a/src/service/rooms/pdu_metadata/data.rs b/src/service/rooms/pdu_metadata/data.rs index 51a43714..3fc06591 100644 --- a/src/service/rooms/pdu_metadata/data.rs +++ b/src/service/rooms/pdu_metadata/data.rs @@ -2,15 +2,21 @@ use std::{mem::size_of, sync::Arc}; use conduit::{ result::LogErr, - utils, - utils::{stream::TryIgnore, ReadyExt}, + utils::{stream::TryIgnore, u64_from_u8, ReadyExt}, PduCount, PduEvent, }; use database::Map; use futures::{Stream, StreamExt}; use ruma::{api::Direction, EventId, RoomId, UserId}; -use crate::{rooms, Dep}; +use crate::{ + rooms, + rooms::{ + short::{ShortEventId, ShortRoomId}, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub(super) struct Data { tofrom_relation: Arc, @@ -46,35 +52,36 @@ impl Data { } pub(super) fn get_relations<'a>( - &'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, dir: Direction, + &'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction, ) -> impl Stream + Send + '_ { - let prefix = target.to_be_bytes().to_vec(); - let mut current = prefix.clone(); - let count_raw = match until { - PduCount::Normal(x) => x.saturating_sub(1), - PduCount::Backfilled(x) => { - current.extend_from_slice(&0_u64.to_be_bytes()); - u64::MAX.saturating_sub(x).saturating_sub(1) - }, - }; - current.extend_from_slice(&count_raw.to_be_bytes()); + let current: RawPduId = PduId { + shortroomid, + shorteventid: from, + } + .into(); match dir { Direction::Forward => self.tofrom_relation.raw_keys_from(¤t).boxed(), Direction::Backward => self.tofrom_relation.rev_raw_keys_from(¤t).boxed(), } .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) - .map(|to_from| utils::u64_from_u8(&to_from[(size_of::())..])) - .filter_map(move |from| async move { - let mut pduid = shortroomid.to_be_bytes().to_vec(); - pduid.extend_from_slice(&from.to_be_bytes()); - let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?; + .ready_take_while(move |key| key.starts_with(&target.to_be_bytes())) + .map(|to_from| u64_from_u8(&to_from[8..16])) + .map(PduCount::from_unsigned) + .filter_map(move |shorteventid| async move { + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid, + } + .into(); + + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + if pdu.sender != user_id { pdu.remove_transaction_id().log_err().ok(); } - Some((PduCount::Normal(from), pdu)) + Some((shorteventid, pdu)) }) } diff --git a/src/service/rooms/pdu_metadata/mod.rs b/src/service/rooms/pdu_metadata/mod.rs index b1cf2049..82d2ee35 100644 --- a/src/service/rooms/pdu_metadata/mod.rs +++ b/src/service/rooms/pdu_metadata/mod.rs @@ -1,18 +1,9 @@ mod data; use std::sync::Arc; -use conduit::{ - at, - utils::{result::FlatOk, stream::ReadyExt, IterStream}, - PduCount, Result, -}; -use 
futures::{FutureExt, StreamExt}; -use ruma::{ - api::{client::relations::get_relating_events, Direction}, - events::{relation::RelationType, TimelineEventType}, - EventId, RoomId, UInt, UserId, -}; -use serde::Deserialize; +use conduit::{PduCount, Result}; +use futures::StreamExt; +use ruma::{api::Direction, EventId, RoomId, UserId}; use self::data::{Data, PdusIterItem}; use crate::{rooms, Dep}; @@ -24,26 +15,14 @@ pub struct Service { struct Services { short: Dep, - state_accessor: Dep, timeline: Dep, } -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelType { - rel_type: RelationType, -} -#[derive(Clone, Debug, Deserialize)] -struct ExtractRelatesToEventId { - #[serde(rename = "m.relates_to")] - relates_to: ExtractRelType, -} - impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { Ok(Arc::new(Self { services: Services { short: args.depend::("rooms::short"), - state_accessor: args.depend::("rooms::state_accessor"), timeline: args.depend::("rooms::timeline"), }, db: Data::new(&args), @@ -64,82 +43,9 @@ impl Service { } } - #[allow(clippy::too_many_arguments)] - pub async fn paginate_relations_with_filter( - &self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option, - filter_rel_type: Option, from: Option<&str>, to: Option<&str>, limit: Option, - recurse: bool, dir: Direction, - ) -> Result { - let from = from - .map(PduCount::try_from_string) - .transpose()? - .unwrap_or_else(|| match dir { - Direction::Forward => PduCount::min(), - Direction::Backward => PduCount::max(), - }); - - let to = to.map(PduCount::try_from_string).flat_ok(); - - // Use limit or else 30, with maximum 100 - let limit: usize = limit - .map(TryInto::try_into) - .flat_ok() - .unwrap_or(30) - .min(100); - - // Spec (v1.10) recommends depth of at least 3 - let depth: u8 = if recurse { - 3 - } else { - 1 - }; - - let events: Vec = self - .get_relations(sender_user, room_id, target, from, limit, depth, dir) - .await - .into_iter() - .filter(|(_, pdu)| { - filter_event_type - .as_ref() - .is_none_or(|kind| *kind == pdu.kind) - }) - .filter(|(_, pdu)| { - filter_rel_type.as_ref().is_none_or(|rel_type| { - pdu.get_content() - .map(|c: ExtractRelatesToEventId| c.relates_to.rel_type) - .is_ok_and(|r| r == *rel_type) - }) - }) - .stream() - .filter_map(|item| self.visibility_filter(sender_user, item)) - .ready_take_while(|(count, _)| Some(*count) != to) - .take(limit) - .collect() - .boxed() - .await; - - let next_batch = match dir { - Direction::Backward => events.first(), - Direction::Forward => events.last(), - } - .map(at!(0)) - .map(|t| t.stringify()); - - Ok(get_relating_events::v1::Response { - next_batch, - prev_batch: Some(from.stringify()), - recursion_depth: recurse.then_some(depth.into()), - chunk: events - .into_iter() - .map(at!(1)) - .map(|pdu| pdu.to_message_like_event()) - .collect(), - }) - } - #[allow(clippy::too_many_arguments)] pub async fn get_relations( - &self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, limit: usize, max_depth: u8, + &self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8, dir: Direction, ) -> Vec { let room_id = self.services.short.get_or_create_shortroomid(room_id).await; @@ -152,7 +58,7 @@ impl Service { let mut pdus: Vec<_> = self .db - .get_relations(user_id, room_id, target, until, dir) + .get_relations(user_id, room_id, target, from, dir) .collect() .await; @@ -167,7 +73,7 @@ impl Service { let relations: Vec<_> = self .db - 
.get_relations(user_id, room_id, target, until, dir) + .get_relations(user_id, room_id, target, from, dir) .collect() .await; @@ -186,16 +92,6 @@ impl Service { pdus } - async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option { - let (_, pdu) = &item; - - self.services - .state_accessor - .user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id) - .await - .then_some(item) - } - #[inline] #[tracing::instrument(skip_all, level = "debug")] pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc]) { diff --git a/src/service/rooms/search/mod.rs b/src/service/rooms/search/mod.rs index 70daded1..1af37d9e 100644 --- a/src/service/rooms/search/mod.rs +++ b/src/service/rooms/search/mod.rs @@ -1,10 +1,10 @@ -use std::{iter, sync::Arc}; +use std::sync::Arc; use arrayvec::ArrayVec; use conduit::{ implement, utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt}, - PduEvent, Result, + PduCount, PduEvent, Result, }; use database::{keyval::Val, Map}; use futures::{Stream, StreamExt}; @@ -66,13 +66,13 @@ impl crate::Service for Service { } #[implement(Service)] -pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { +pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { let batch = tokenize(message_body) .map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here (key, Vec::::new()) }) .collect::>(); @@ -81,12 +81,12 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { } #[implement(Service)] -pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) { +pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) { let batch = tokenize(message_body).map(|word| { let mut key = shortroomid.to_be_bytes().to_vec(); key.extend_from_slice(word.as_bytes()); key.push(0xFF); - key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here + key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here key }); @@ -159,24 +159,24 @@ fn search_pdu_ids_query_words<'a>( &'a self, shortroomid: ShortRoomId, word: &'a str, ) -> impl Stream + Send + '_ { self.search_pdu_ids_query_word(shortroomid, word) - .ready_filter_map(move |key| { - key[prefix_len(word)..] 
- .chunks_exact(PduId::LEN) - .next() - .map(RawPduId::try_from) - .and_then(Result::ok) + .map(move |key| -> RawPduId { + let key = &key[prefix_len(word)..]; + key.into() }) } /// Iterate over raw database results for a word #[implement(Service)] fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream> + Send + '_ { - const PDUID_LEN: usize = PduId::LEN; // rustc says const'ing this not yet stable - let end_id: ArrayVec = iter::repeat(u8::MAX).take(PduId::LEN).collect(); + let end_id: RawPduId = PduId { + shortroomid, + shorteventid: PduCount::max(), + } + .into(); // Newest pdus first - let end = make_tokenid(shortroomid, word, end_id.as_slice()); + let end = make_tokenid(shortroomid, word, &end_id); let prefix = make_prefix(shortroomid, word); self.db .tokenids @@ -196,11 +196,9 @@ fn tokenize(body: &str) -> impl Iterator + Send + '_ { .map(str::to_lowercase) } -fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId { - debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch"); - +fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId { let mut key = make_prefix(shortroomid, word); - key.extend_from_slice(pdu_id); + key.extend_from_slice(pdu_id.as_ref()); key } diff --git a/src/service/rooms/short/mod.rs b/src/service/rooms/short/mod.rs index a903ef22..9fddf099 100644 --- a/src/service/rooms/short/mod.rs +++ b/src/service/rooms/short/mod.rs @@ -1,5 +1,6 @@ use std::{mem::size_of_val, sync::Arc}; +pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId}; use conduit::{err, implement, utils, Result}; use database::{Deserialized, Map}; use ruma::{events::StateEventType, EventId, RoomId}; @@ -26,9 +27,6 @@ struct Services { pub type ShortStateHash = ShortId; pub type ShortStateKey = ShortId; -pub type ShortEventId = ShortId; -pub type ShortRoomId = ShortId; -pub type ShortId = u64; impl crate::Service for Service { fn build(args: crate::Args<'_>) -> Result> { @@ -52,7 +50,7 @@ impl crate::Service for Service { #[implement(Service)] pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId { - const BUFSIZE: usize = size_of::(); + const BUFSIZE: usize = size_of::(); if let Ok(shorteventid) = self .db @@ -88,7 +86,7 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> .map(|(i, result)| match result { Ok(ref short) => utils::u64_from_u8(short), Err(_) => { - const BUFSIZE: usize = size_of::(); + const BUFSIZE: usize = size_of::(); let short = self.services.globals.next_count().unwrap(); debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed"); diff --git a/src/service/rooms/spaces/mod.rs b/src/service/rooms/spaces/mod.rs index 5aea5f6a..37272dca 100644 --- a/src/service/rooms/spaces/mod.rs +++ b/src/service/rooms/spaces/mod.rs @@ -33,7 +33,7 @@ use ruma::{ }; use tokio::sync::Mutex; -use crate::{rooms, sending, Dep}; +use crate::{rooms, rooms::short::ShortRoomId, sending, Dep}; pub struct CachedSpaceHierarchySummary { summary: SpaceHierarchyParentSummary, @@ -49,7 +49,7 @@ pub enum SummaryAccessibility { pub struct PaginationToken { /// Path down the hierarchy of the room to start the response at, /// excluding the root space. 
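The search-index keys above keep their shape; they just append the new `RawPduId` bytes instead of a raw slice: shortroomid, the tokenized word, a `0xFF` separator, then the pdu id. A sketch of that key construction with a `Vec<u8>` standing in for the `TokenId` array (`make_token_key` is a hypothetical name for what `make_prefix` + `make_tokenid` do together):

```rust
fn make_token_key(shortroomid: u64, word: &str, pdu_id: &[u8]) -> Vec<u8> {
    let mut key = shortroomid.to_be_bytes().to_vec();
    key.extend_from_slice(word.as_bytes()); // already lowercased by tokenize()
    key.push(0xFF);                         // separator between the word and the pdu id
    key.extend_from_slice(pdu_id);          // RawPduId::as_ref() in the real code
    key
}

fn main() {
    let pdu_id = [0_u8; 16]; // a normal-length raw pdu id is 16 bytes
    let key = make_token_key(7, "hello", &pdu_id);
    assert_eq!(&key[0..8], &7_u64.to_be_bytes());
    assert_eq!(&key[8..13], b"hello");
    assert_eq!(key[13], 0xFF);
    assert_eq!(key.len(), 8 + 5 + 1 + 16);
}
```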
- pub short_room_ids: Vec, + pub short_room_ids: Vec, pub limit: UInt, pub max_depth: UInt, pub suggested_only: bool, @@ -448,7 +448,7 @@ impl Service { } pub async fn get_client_hierarchy( - &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: u64, + &self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec, max_depth: u64, suggested_only: bool, ) -> Result { let mut parents = VecDeque::new(); diff --git a/src/service/rooms/state/mod.rs b/src/service/rooms/state/mod.rs index 34fab079..71a3900c 100644 --- a/src/service/rooms/state/mod.rs +++ b/src/service/rooms/state/mod.rs @@ -95,7 +95,7 @@ impl Service { let event_ids = statediffnew.iter().stream().filter_map(|new| { self.services .state_compressor - .parse_compressed_state_event(new) + .parse_compressed_state_event(*new) .map_ok_or_else(|_| None, |(_, event_id)| Some(event_id)) }); @@ -428,7 +428,7 @@ impl Service { let Ok((shortstatekey, event_id)) = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await else { continue; diff --git a/src/service/rooms/state_accessor/data.rs b/src/service/rooms/state_accessor/data.rs index 9c96785f..06cd648c 100644 --- a/src/service/rooms/state_accessor/data.rs +++ b/src/service/rooms/state_accessor/data.rs @@ -53,7 +53,7 @@ impl Data { let parsed = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await?; result.insert(parsed.0, parsed.1); @@ -86,7 +86,7 @@ impl Data { let (_, eventid) = self .services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .await?; if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await { @@ -132,7 +132,7 @@ impl Data { self.services .state_compressor - .parse_compressed_state_event(compressed) + .parse_compressed_state_event(*compressed) .map_ok(|(_, id)| id) .map_err(|e| { err!(Database(error!( diff --git a/src/service/rooms/state_accessor/mod.rs b/src/service/rooms/state_accessor/mod.rs index a2cc27e8..d51da8af 100644 --- a/src/service/rooms/state_accessor/mod.rs +++ b/src/service/rooms/state_accessor/mod.rs @@ -39,13 +39,17 @@ use ruma::{ use serde::Deserialize; use self::data::Data; -use crate::{rooms, rooms::state::RoomMutexGuard, Dep}; +use crate::{ + rooms, + rooms::{short::ShortStateHash, state::RoomMutexGuard}, + Dep, +}; pub struct Service { services: Services, db: Data, - pub server_visibility_cache: Mutex>, - pub user_visibility_cache: Mutex>, + pub server_visibility_cache: Mutex>, + pub user_visibility_cache: Mutex>, } struct Services { @@ -94,11 +98,13 @@ impl Service { /// Builds a StateMap by iterating over all keys that start /// with state_hash, this gives the full state for the given state_hash. #[tracing::instrument(skip(self), level = "debug")] - pub async fn state_full_ids(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result>> { self.db.state_full_ids(shortstatehash).await } - pub async fn state_full(&self, shortstatehash: u64) -> Result>> { + pub async fn state_full( + &self, shortstatehash: ShortStateHash, + ) -> Result>> { self.db.state_full(shortstatehash).await } @@ -106,7 +112,7 @@ impl Service { /// `state_key`). 
#[tracing::instrument(skip(self), level = "debug")] pub async fn state_get_id( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { self.db .state_get_id(shortstatehash, event_type, state_key) @@ -117,7 +123,7 @@ impl Service { /// `state_key`). #[inline] pub async fn state_get( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result> { self.db .state_get(shortstatehash, event_type, state_key) @@ -126,7 +132,7 @@ impl Service { /// Returns a single PDU from `room_id` with key (`event_type`,`state_key`). pub async fn state_get_content( - &self, shortstatehash: u64, event_type: &StateEventType, state_key: &str, + &self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str, ) -> Result where T: for<'de> Deserialize<'de> + Send, @@ -137,7 +143,7 @@ impl Service { } /// Get membership for given user in state - async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState { + async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState { self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str()) .await .map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership) @@ -145,14 +151,14 @@ impl Service { /// The user was a joined member at this state (potentially in the past) #[inline] - async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool { + async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { self.user_membership(shortstatehash, user_id).await == MembershipState::Join } /// The user was an invited or joined room member at this state (potentially /// in the past) #[inline] - async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool { + async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool { let s = self.user_membership(shortstatehash, user_id).await; s == MembershipState::Join || s == MembershipState::Invite } @@ -285,7 +291,7 @@ impl Service { } /// Returns the state hash for this pdu. 
- pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { + pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result { self.db.pdu_shortstatehash(event_id).await } diff --git a/src/service/rooms/state_compressor/mod.rs b/src/service/rooms/state_compressor/mod.rs index e213490b..bf90d5c4 100644 --- a/src/service/rooms/state_compressor/mod.rs +++ b/src/service/rooms/state_compressor/mod.rs @@ -34,25 +34,26 @@ struct Data { #[derive(Clone)] struct StateDiff { parent: Option, - added: Arc>, - removed: Arc>, + added: Arc, + removed: Arc, } #[derive(Clone, Default)] pub struct ShortStateInfo { pub shortstatehash: ShortStateHash, - pub full_state: Arc>, - pub added: Arc>, - pub removed: Arc>, + pub full_state: Arc, + pub added: Arc, + pub removed: Arc, } #[derive(Clone, Default)] pub struct HashSetCompressStateEvent { pub shortstatehash: ShortStateHash, - pub added: Arc>, - pub removed: Arc>, + pub added: Arc, + pub removed: Arc, } +pub(crate) type CompressedState = HashSet; pub(crate) type CompressedStateEvent = [u8; 2 * size_of::()]; type StateInfoLruCache = LruCache; type ShortStateInfoVec = Vec; @@ -105,7 +106,7 @@ impl Service { removed, } = self.get_statediff(shortstatehash).await?; - if let Some(parent) = parent { + let response = if let Some(parent) = parent { let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?; let mut state = (*response.last().expect("at least one response").full_state).clone(); state.extend(added.iter().copied()); @@ -121,27 +122,22 @@ impl Service { removed: Arc::new(removed), }); - self.stateinfo_cache - .lock() - .expect("locked") - .insert(shortstatehash, response.clone()); - - Ok(response) + response } else { - let response = vec![ShortStateInfo { + vec![ShortStateInfo { shortstatehash, full_state: added.clone(), added, removed, - }]; + }] + }; - self.stateinfo_cache - .lock() - .expect("locked") - .insert(shortstatehash, response.clone()); + self.stateinfo_cache + .lock() + .expect("locked") + .insert(shortstatehash, response.clone()); - Ok(response) - } + Ok(response) } pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent { @@ -161,7 +157,7 @@ impl Service { /// Returns shortstatekey, event id #[inline] pub async fn parse_compressed_state_event( - &self, compressed_event: &CompressedStateEvent, + &self, compressed_event: CompressedStateEvent, ) -> Result<(ShortStateKey, Arc)> { use utils::u64_from_u8; diff --git a/src/service/rooms/threads/data.rs b/src/service/rooms/threads/data.rs index f50b812c..c26dabb4 100644 --- a/src/service/rooms/threads/data.rs +++ b/src/service/rooms/threads/data.rs @@ -1,17 +1,22 @@ -use std::{mem::size_of, sync::Arc}; +use std::sync::Arc; use conduit::{ - checked, result::LogErr, - utils, utils::{stream::TryIgnore, ReadyExt}, - PduEvent, Result, + PduCount, PduEvent, Result, }; use database::{Deserialized, Map}; use futures::{Stream, StreamExt}; use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId}; -use crate::{rooms, Dep}; +use crate::{ + rooms, + rooms::{ + short::ShortRoomId, + timeline::{PduId, RawPduId}, + }, + Dep, +}; pub(super) struct Data { threadid_userids: Arc, @@ -35,40 +40,39 @@ impl Data { } } + #[inline] pub(super) async fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads, - ) -> Result + Send + 'a> { - let prefix = self - .services - .short - .get_shortroomid(room_id) - .await? 
- .to_be_bytes() - .to_vec(); + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, _include: &'a IncludeThreads, + ) -> Result + Send + 'a> { + let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?; - let mut current = prefix.clone(); - current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes()); + let current: RawPduId = PduId { + shortroomid, + shorteventid: until.saturating_sub(1), + } + .into(); let stream = self .threadid_userids .rev_raw_keys_from(¤t) .ignore_err() - .ready_take_while(move |key| key.starts_with(&prefix)) - .map(|pduid| (utils::u64_from_u8(&pduid[(size_of::())..]), pduid)) - .filter_map(move |(count, pduid)| async move { - let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?; + .map(RawPduId::from) + .ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes()) + .filter_map(move |pdu_id| async move { + let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?; + let pdu_id: PduId = pdu_id.into(); if pdu.sender != user_id { pdu.remove_transaction_id().log_err().ok(); } - Some((count, pdu)) + Some((pdu_id.shorteventid, pdu)) }); Ok(stream) } - pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> { + pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result { let users = participants .iter() .map(|user| user.as_bytes()) @@ -80,7 +84,7 @@ impl Data { Ok(()) } - pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result> { - self.threadid_userids.qry(root_id).await.deserialized() + pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result> { + self.threadid_userids.get(root_id).await.deserialized() } } diff --git a/src/service/rooms/threads/mod.rs b/src/service/rooms/threads/mod.rs index 2eafe5d5..02503030 100644 --- a/src/service/rooms/threads/mod.rs +++ b/src/service/rooms/threads/mod.rs @@ -2,7 +2,7 @@ mod data; use std::{collections::BTreeMap, sync::Arc}; -use conduit::{err, PduEvent, Result}; +use conduit::{err, PduCount, PduEvent, Result}; use data::Data; use futures::Stream; use ruma::{ @@ -37,8 +37,8 @@ impl crate::Service for Service { impl Service { pub async fn threads_until<'a>( - &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads, - ) -> Result + Send + 'a> { + &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, include: &'a IncludeThreads, + ) -> Result + Send + 'a> { self.db .threads_until(user_id, room_id, until, include) .await diff --git a/src/service/rooms/timeline/data.rs b/src/service/rooms/timeline/data.rs index 5428a3b9..19dc5325 100644 --- a/src/service/rooms/timeline/data.rs +++ b/src/service/rooms/timeline/data.rs @@ -1,14 +1,13 @@ use std::{ collections::{hash_map, HashMap}, - mem::size_of, sync::Arc, }; use conduit::{ - err, expected, + at, err, result::{LogErr, NotFound}, utils, - utils::{future::TryExtExt, stream::TryIgnore, u64_from_u8, ReadyExt}, + utils::{future::TryExtExt, stream::TryIgnore, ReadyExt}, Err, PduCount, PduEvent, Result, }; use database::{Database, Deserialized, Json, KeyVal, Map}; @@ -16,7 +15,8 @@ use futures::{Stream, StreamExt}; use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId}; use tokio::sync::Mutex; -use crate::{rooms, Dep}; +use super::{PduId, RawPduId}; +use crate::{rooms, rooms::short::ShortRoomId, Dep}; pub(super) struct Data { eventid_outlierpdu: Arc, @@ -58,30 +58,25 @@ impl Data { 
.lasttimelinecount_cache .lock() .await - .entry(room_id.to_owned()) + .entry(room_id.into()) { - hash_map::Entry::Vacant(v) => { - if let Some(last_count) = self - .pdus_until(sender_user, room_id, PduCount::max()) - .await? - .next() - .await - { - Ok(*v.insert(last_count.0)) - } else { - Ok(PduCount::Normal(0)) - } - }, hash_map::Entry::Occupied(o) => Ok(*o.get()), + hash_map::Entry::Vacant(v) => Ok(self + .pdus_until(sender_user, room_id, PduCount::max()) + .await? + .next() + .await + .map(at!(0)) + .filter(|&count| matches!(count, PduCount::Normal(_))) + .map_or_else(PduCount::max, |count| *v.insert(count))), } } /// Returns the `count` of this pdu's id. pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result { - self.eventid_pduid - .get(event_id) + self.get_pdu_id(event_id) .await - .map(|pdu_id| pdu_count(&pdu_id)) + .map(|pdu_id| pdu_id.pdu_count()) } /// Returns the json of a pdu. @@ -102,8 +97,11 @@ impl Data { /// Returns the pdu's id. #[inline] - pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result> { - self.eventid_pduid.get(event_id).await + pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result { + self.eventid_pduid + .get(event_id) + .await + .map(|handle| RawPduId::from(&*handle)) } /// Returns the pdu directly from `eventid_pduid` only. @@ -154,34 +152,40 @@ impl Data { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { + pub(super) async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result { self.pduid_pdu.get(pdu_id).await.deserialized() } /// Returns the pdu as a `BTreeMap`. - pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result { self.pduid_pdu.get(pdu_id).await.deserialized() } - pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) { + pub(super) async fn append_pdu( + &self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount, + ) { + debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal"); + self.pduid_pdu.raw_put(pdu_id, Json(json)); self.lasttimelinecount_cache .lock() .await - .insert(pdu.room_id.clone(), PduCount::Normal(count)); + .insert(pdu.room_id.clone(), count); self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id); self.eventid_outlierpdu.remove(pdu.event_id.as_bytes()); } - pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) { + pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) { self.pduid_pdu.raw_put(pdu_id, Json(json)); self.eventid_pduid.insert(event_id, pdu_id); self.eventid_outlierpdu.remove(event_id); } /// Removes a pdu and creates a new one with the same id. 
- pub(super) async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result { + pub(super) async fn replace_pdu( + &self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent, + ) -> Result { if self.pduid_pdu.get(pdu_id).await.is_not_found() { return Err!(Request(NotFound("PDU does not exist."))); } @@ -197,13 +201,14 @@ impl Data { pub(super) async fn pdus_until<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, ) -> Result + Send + 'a> { - let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?; + let current = self.count_to_id(room_id, until, true).await?; + let prefix = current.shortroomid(); let stream = self .pduid_pdu .rev_raw_stream_from(¤t) .ignore_err() .ready_take_while(move |(key, _)| key.starts_with(&prefix)) - .map(move |item| Self::each_pdu(item, user_id)); + .map(|item| Self::each_pdu(item, user_id)); Ok(stream) } @@ -211,7 +216,8 @@ impl Data { pub(super) async fn pdus_after<'a>( &'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount, ) -> Result + Send + 'a> { - let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?; + let current = self.count_to_id(room_id, from, false).await?; + let prefix = current.shortroomid(); let stream = self .pduid_pdu .raw_stream_from(¤t) @@ -223,6 +229,8 @@ impl Data { } fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem { + let pdu_id: RawPduId = pdu_id.into(); + let mut pdu = serde_json::from_slice::(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON"); @@ -231,9 +239,8 @@ impl Data { } pdu.add_age().log_err().ok(); - let count = pdu_count(pdu_id); - (count, pdu) + (pdu_id.pdu_count(), pdu) } pub(super) fn increment_notification_counts( @@ -256,56 +263,25 @@ impl Data { } } - pub(super) async fn count_to_id( - &self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool, - ) -> Result<(Vec, Vec)> { - let prefix = self + async fn count_to_id(&self, room_id: &RoomId, count: PduCount, subtract: bool) -> Result { + let shortroomid: ShortRoomId = self .services .short .get_shortroomid(room_id) .await - .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))? - .to_be_bytes() - .to_vec(); + .map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?; - let mut pdu_id = prefix.clone(); // +1 so we don't send the base event - let count_raw = match count { - PduCount::Normal(x) => { - if subtract { - x.saturating_sub(offset) - } else { - x.saturating_add(offset) - } - }, - PduCount::Backfilled(x) => { - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - let num = u64::MAX.saturating_sub(x); - if subtract { - num.saturating_sub(offset) - } else { - num.saturating_add(offset) - } + let pdu_id = PduId { + shortroomid, + shorteventid: if subtract { + count.checked_sub(1)? + } else { + count.checked_add(1)? }, }; - pdu_id.extend_from_slice(&count_raw.to_be_bytes()); - Ok((prefix, pdu_id)) - } -} - -/// Returns the `count` of this pdu's id. 
-pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount { - const STRIDE: usize = size_of::(); - - let pdu_id_len = pdu_id.len(); - let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]); - let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]); - - if second_last_u64 == 0 { - PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)) - } else { - PduCount::Normal(last_u64) + Ok(pdu_id.into()) } } diff --git a/src/service/rooms/timeline/mod.rs b/src/service/rooms/timeline/mod.rs index e45bf7e5..86a47919 100644 --- a/src/service/rooms/timeline/mod.rs +++ b/src/service/rooms/timeline/mod.rs @@ -1,5 +1,4 @@ mod data; -mod pduid; use std::{ cmp, @@ -15,6 +14,7 @@ use conduit::{ utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt}, validated, warn, Err, Error, Result, Server, }; +pub use conduit::{PduId, RawPduId}; use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt}; use ruma::{ api::federation, @@ -39,13 +39,13 @@ use serde::Deserialize; use serde_json::value::{to_raw_value, RawValue as RawJsonValue}; use self::data::Data; -pub use self::{ - data::PdusIterItem, - pduid::{PduId, RawPduId}, -}; +pub use self::data::PdusIterItem; use crate::{ - account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms, - rooms::state_compressor::CompressedStateEvent, sending, server_keys, users, Dep, + account_data, admin, appservice, + appservice::NamespaceRegex, + globals, pusher, rooms, + rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent}, + sending, server_keys, users, Dep, }; // Update Relationships @@ -229,9 +229,7 @@ impl Service { /// Returns the pdu's id. #[inline] - pub async fn get_pdu_id(&self, event_id: &EventId) -> Result> { - self.db.get_pdu_id(event_id).await - } + pub async fn get_pdu_id(&self, event_id: &EventId) -> Result { self.db.get_pdu_id(event_id).await } /// Returns the pdu. /// @@ -256,16 +254,16 @@ impl Service { /// Returns the pdu. /// /// This does __NOT__ check the outliers `Tree`. - pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result { self.db.get_pdu_from_id(pdu_id).await } + pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result { self.db.get_pdu_from_id(pdu_id).await } /// Returns the pdu as a `BTreeMap`. - pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result { + pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result { self.db.get_pdu_json_from_id(pdu_id).await } /// Removes a pdu and creates a new one with the same id. #[tracing::instrument(skip(self), level = "debug")] - pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { + pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> { self.db.replace_pdu(pdu_id, pdu_json, pdu).await } @@ -282,7 +280,7 @@ impl Service { mut pdu_json: CanonicalJsonObject, leaves: Vec, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result> { + ) -> Result { // Coalesce database writes for the remainder of this scope. 
let _cork = self.db.db.cork_and_flush(); @@ -359,9 +357,12 @@ impl Service { .user .reset_notification_counts(&pdu.sender, &pdu.room_id); - let count2 = self.services.globals.next_count().unwrap(); - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&count2.to_be_bytes()); + let count2 = PduCount::Normal(self.services.globals.next_count().unwrap()); + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: count2, + } + .into(); // Insert pdu self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await; @@ -544,7 +545,7 @@ impl Service { if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount); + .add_relation(count2, related_pducount); } } @@ -558,7 +559,7 @@ impl Service { if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await { self.services .pdu_metadata - .add_relation(PduCount::Normal(count2), related_pducount); + .add_relation(count2, related_pducount); } }, Relation::Thread(thread) => { @@ -580,7 +581,7 @@ impl Service { { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } @@ -596,7 +597,7 @@ impl Service { if state_key_uid == appservice_uid { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; continue; } } @@ -623,7 +624,7 @@ impl Service { { self.services .sending - .send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?; + .send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?; } } @@ -935,7 +936,7 @@ impl Service { state_ids_compressed: Arc>, soft_fail: bool, state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex - ) -> Result>> { + ) -> Result> { // We append to state before appending the pdu, so we don't have a moment in // time with the pdu without it's state. This is okay because append_pdu can't // fail. @@ -993,7 +994,7 @@ impl Service { /// Replace a PDU with the redacted form. 
#[tracing::instrument(skip(self, reason))] - pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> { + pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result { // TODO: Don't reserialize, keep original json let Ok(pdu_id) = self.get_pdu_id(event_id).await else { // If event does not exist, just noop @@ -1133,7 +1134,6 @@ impl Service { // Skip the PDU if we already have it as a timeline event if let Ok(pdu_id) = self.get_pdu_id(&event_id).await { - let pdu_id = pdu_id.to_vec(); debug!("We already know {event_id} at {pdu_id:?}"); return Ok(()); } @@ -1158,11 +1158,13 @@ impl Service { let insert_lock = self.mutex_insert.lock(&room_id).await; - let max = u64::MAX; - let count = self.services.globals.next_count().unwrap(); - let mut pdu_id = shortroomid.to_be_bytes().to_vec(); - pdu_id.extend_from_slice(&0_u64.to_be_bytes()); - pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes()); + let count: i64 = self.services.globals.next_count().unwrap().try_into()?; + + let pdu_id: RawPduId = PduId { + shortroomid, + shorteventid: PduCount::Backfilled(validated!(0 - count)), + } + .into(); // Insert pdu self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value); @@ -1246,16 +1248,3 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res Ok(()) } - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn comparisons() { - assert!(PduCount::Normal(1) < PduCount::Normal(2)); - assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1)); - assert!(PduCount::Normal(1) > PduCount::Backfilled(1)); - assert!(PduCount::Backfilled(1) < PduCount::Normal(1)); - } -} diff --git a/src/service/rooms/timeline/pduid.rs b/src/service/rooms/timeline/pduid.rs deleted file mode 100644 index b43c382c..00000000 --- a/src/service/rooms/timeline/pduid.rs +++ /dev/null @@ -1,13 +0,0 @@ -use crate::rooms::short::{ShortEventId, ShortRoomId}; - -#[derive(Clone, Copy)] -pub struct PduId { - _room_id: ShortRoomId, - _event_id: ShortEventId, -} - -pub type RawPduId = [u8; PduId::LEN]; - -impl PduId { - pub const LEN: usize = size_of::() + size_of::(); -} diff --git a/src/service/rooms/user/mod.rs b/src/service/rooms/user/mod.rs index e484203d..99587134 100644 --- a/src/service/rooms/user/mod.rs +++ b/src/service/rooms/user/mod.rs @@ -5,7 +5,7 @@ use database::{Deserialized, Map}; use futures::{pin_mut, Stream, StreamExt}; use ruma::{RoomId, UserId}; -use crate::{globals, rooms, Dep}; +use crate::{globals, rooms, rooms::short::ShortStateHash, Dep}; pub struct Service { db: Data, @@ -93,7 +93,7 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) - } #[implement(Service)] -pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) { +pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) { let shortroomid = self .services .short @@ -108,7 +108,7 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, } #[implement(Service)] -pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { +pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result { let shortroomid = self.services.short.get_shortroomid(room_id).await?; let key: &[u64] = &[shortroomid, token]; diff --git a/src/service/sending/data.rs b/src/service/sending/data.rs index f75a212c..cd25776a 100644 --- a/src/service/sending/data.rs 
+++ b/src/service/sending/data.rs
@@ -115,10 +115,10 @@ impl Data {
 		let mut keys = Vec::new();
 		for (event, destination) in requests {
 			let mut key = destination.get_prefix();
-			if let SendingEvent::Pdu(value) = &event {
-				key.extend_from_slice(value);
+			if let SendingEvent::Pdu(value) = event {
+				key.extend(value.as_ref());
 			} else {
-				key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
+				key.extend(&self.services.globals.next_count().unwrap().to_be_bytes());
 			}
 			let value = if let SendingEvent::Edu(value) = &event {
 				&**value
@@ -175,7 +175,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
 		(
 			Destination::Appservice(server),
 			if value.is_empty() {
-				SendingEvent::Pdu(event.to_vec())
+				SendingEvent::Pdu(event.into())
 			} else {
 				SendingEvent::Edu(value.to_vec())
 			},
@@ -202,7 +202,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
 		(
 			Destination::Push(user_id, pushkey_string),
 			if value.is_empty() {
-				SendingEvent::Pdu(event.to_vec())
+				SendingEvent::Pdu(event.into())
 			} else {
 				// I'm pretty sure this should never be called
 				SendingEvent::Edu(value.to_vec())
@@ -225,7 +225,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
 				.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
 			),
 			if value.is_empty() {
-				SendingEvent::Pdu(event.to_vec())
+				SendingEvent::Pdu(event.into())
 			} else {
 				SendingEvent::Edu(value.to_vec())
 			},
diff --git a/src/service/sending/mod.rs b/src/service/sending/mod.rs
index ea266883..77997f69 100644
--- a/src/service/sending/mod.rs
+++ b/src/service/sending/mod.rs
@@ -24,7 +24,10 @@ pub use self::{
 	dest::Destination,
 	sender::{EDU_LIMIT, PDU_LIMIT},
 };
-use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_keys, users, Dep};
+use crate::{
+	account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users,
+	Dep,
+};
 
 pub struct Service {
 	server: Arc<Server>,
@@ -61,9 +64,9 @@ struct Msg {
 #[allow(clippy::module_name_repetitions)]
 #[derive(Clone, Debug, PartialEq, Eq, Hash)]
 pub enum SendingEvent {
-	Pdu(Vec<u8>), // pduid
-	Edu(Vec<u8>), // pdu json
-	Flush,        // none
+	Pdu(RawPduId), // pduid
+	Edu(Vec<u8>),  // pdu json
+	Flush,         // none
 }
 
 #[async_trait]
@@ -110,9 +113,9 @@ impl crate::Service for Service {
 impl Service {
 	#[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")]
-	pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> {
+	pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result {
 		let dest = Destination::Push(user.to_owned(), pushkey);
-		let event = SendingEvent::Pdu(pdu_id.to_owned());
+		let event = SendingEvent::Pdu(*pdu_id);
 		let _cork = self.db.db.cork();
 		let keys = self.db.queue_requests(&[(&event, &dest)]);
 		self.dispatch(Msg {
@@ -123,7 +126,7 @@ impl Service {
 	}
 
 	#[tracing::instrument(skip(self), level = "debug")]
-	pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> {
+	pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result {
 		let dest = Destination::Appservice(appservice_id);
 		let event = SendingEvent::Pdu(pdu_id);
 		let _cork = self.db.db.cork();
@@ -136,7 +139,7 @@ impl Service {
 	}
 
 	#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
-	pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
+	pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result {
 		let servers = self
 			.services
 			.state_cache
@@ -147,13 +150,13 @@ impl Service {
 	}
 
 	#[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
-	pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()>
+	pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result
 	where
 		S: Stream<Item = &'a ServerName> + Send + 'a,
 	{
 		let _cork = self.db.db.cork();
 		let requests = servers
-			.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into())))
+			.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned())))
 			.collect::<Vec<_>>()
 			.await;
 
diff --git a/src/service/sending/sender.rs b/src/service/sending/sender.rs
index d9087d44..464d186b 100644
--- a/src/service/sending/sender.rs
+++ b/src/service/sending/sender.rs
@@ -536,7 +536,8 @@ impl Service {
 			&events
 				.iter()
 				.map(|e| match e {
-					SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
+					SendingEvent::Edu(b) => &**b,
+					SendingEvent::Pdu(b) => b.as_ref(),
 					SendingEvent::Flush => &[],
 				})
 				.collect::<Vec<_>>(),
@@ -660,7 +661,8 @@ impl Service {
 			&events
 				.iter()
 				.map(|e| match e {
-					SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
+					SendingEvent::Edu(b) => &**b,
+					SendingEvent::Pdu(b) => b.as_ref(),
 					SendingEvent::Flush => &[],
 				})
 				.collect::<Vec<_>>(),
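
The core of this patch is replacing loose `Vec<u8>` pdu ids with the typed `PduId`/`RawPduId` pair, which now lives in the `conduit` core crate and is therefore not visible in this diff. Judging from how it is used above (`PduId { shortroomid, shorteventid }.into()`, `pdu_id.shortroomid()`, `pdu_id.pdu_count()`, and `as_ref()` when building queue keys), it behaves roughly like the sketch below; the field layout, method names, and big-endian encoding here are assumptions for illustration, not the crate's actual implementation.

```rust
use std::mem::size_of;

// Illustrative stand-ins only: the real PduId/RawPduId are defined in the
// `conduit` core crate and are not part of this diff.
type ShortRoomId = u64;
type ShortEventId = u64;

#[derive(Clone, Copy)]
struct PduId {
	shortroomid: ShortRoomId,
	// The diff stores a PduCount here; a plain ShortEventId keeps the sketch small.
	shorteventid: ShortEventId,
}

/// Fixed-width database key: big-endian room half followed by the event half (assumed layout).
#[derive(Clone, Copy, PartialEq, Eq)]
struct RawPduId([u8; size_of::<ShortRoomId>() + size_of::<ShortEventId>()]);

impl From<PduId> for RawPduId {
	fn from(id: PduId) -> Self {
		let mut raw = [0_u8; size_of::<ShortRoomId>() + size_of::<ShortEventId>()];
		raw[..8].copy_from_slice(&id.shortroomid.to_be_bytes());
		raw[8..].copy_from_slice(&id.shorteventid.to_be_bytes());
		Self(raw)
	}
}

impl RawPduId {
	/// Room prefix, comparable against `shortroomid.to_be_bytes()` as in `threads_until`.
	fn shortroomid(&self) -> [u8; 8] { self.0[..8].try_into().expect("8 bytes") }

	/// Event half; the real type decodes this into a `PduCount` via `pdu_count()`.
	fn shorteventid(&self) -> ShortEventId { u64::from_be_bytes(self.0[8..].try_into().expect("8 bytes")) }
}

impl AsRef<[u8]> for RawPduId {
	fn as_ref(&self) -> &[u8] { &self.0 }
}

fn main() {
	let raw: RawPduId = PduId { shortroomid: 5, shorteventid: 1234 }.into();
	assert_eq!(raw.shortroomid(), 5_u64.to_be_bytes());
	assert_eq!(raw.shorteventid(), 1234);
	assert_eq!(raw.as_ref().len(), 16);
}
```

Because the raw form is a small `Copy` array rather than a heap allocation, call sites above can pass it by value (`send_pdu_appservice(…, pdu_id)`) instead of cloning.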
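`PduCount` likewise moved into the core crate, taking the removed ordering tests with it. The diff shows backfilled counts being constructed as negative offsets (`PduCount::Backfilled(validated!(0 - count))`) and pagination stepping one count past the base token (`count.checked_sub(1)` / `checked_add(1)`, per the "+1 so we don't send the base event" comment). The following is a self-contained sketch with those properties, an illustration rather than the real type (the real helpers return `Result` rather than `Option`):

```rust
use std::cmp::Ordering;

// Illustrative stand-in: normal counts are positive timeline positions,
// backfilled counts are negative, cf. `PduCount::Backfilled(0 - count)` above.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum PduCount {
	Normal(u64),
	Backfilled(i64),
}

impl PduCount {
	/// Collapse both variants onto one signed axis: backfilled < 0 <= normal.
	fn position(self) -> i64 {
		match self {
			Self::Normal(n) => i64::try_from(n).expect("demo values fit in i64"),
			Self::Backfilled(n) => n,
		}
	}

	fn from_position(p: i64) -> Self {
		if p < 0 { Self::Backfilled(p) } else { Self::Normal(p.unsigned_abs()) }
	}

	/// Step toward older events, as `count_to_id` does for `pdus_until`.
	fn checked_sub(self, n: u64) -> Option<Self> {
		Some(Self::from_position(self.position().checked_sub(i64::try_from(n).ok()?)?))
	}

	/// Step toward newer events, as `count_to_id` does for `pdus_after`.
	fn checked_add(self, n: u64) -> Option<Self> {
		Some(Self::from_position(self.position().checked_add(i64::try_from(n).ok()?)?))
	}
}

impl Ord for PduCount {
	fn cmp(&self, other: &Self) -> Ordering { self.position().cmp(&other.position()) }
}
impl PartialOrd for PduCount {
	fn partial_cmp(&self, other: &Self) -> Option<Ordering> { Some(self.cmp(other)) }
}

fn main() {
	// The ordering the removed unit tests asserted, restated for the signed representation:
	assert!(PduCount::Normal(1) < PduCount::Normal(2));
	assert!(PduCount::Backfilled(-2) < PduCount::Backfilled(-1));
	assert!(PduCount::Backfilled(-1) < PduCount::Normal(1));

	// Pagination offset: walking backwards from token 10 starts at 9, so the base
	// event itself is not returned again; walking forwards starts at 11.
	assert_eq!(PduCount::Normal(10).checked_sub(1), Some(PduCount::Normal(9)));
	assert_eq!(PduCount::Normal(10).checked_add(1), Some(PduCount::Normal(11)));
}
```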
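`CompressedState` is now just a named alias for the existing `HashSet<CompressedStateEvent>`, where each entry packs a short state key together with a short event id into one fixed-width value. A minimal sketch of that packing, assuming two big-endian `u64` halves as suggested by `2 * size_of::<u64>()` and the `u64_from_u8` call in `parse_compressed_state_event` (the real service additionally resolves the event half back to an `Arc<EventId>`):

```rust
use std::mem::size_of;

// Standalone sketch of the compressed state-event packing; names are illustrative.
type ShortStateKey = u64;
type ShortEventId = u64;
type CompressedStateEvent = [u8; 2 * size_of::<u64>()];

/// Pack shortstatekey ‖ shorteventid into the fixed 16-byte compressed form.
fn compress(shortstatekey: ShortStateKey, shorteventid: ShortEventId) -> CompressedStateEvent {
	let mut v = [0_u8; 2 * size_of::<u64>()];
	v[..8].copy_from_slice(&shortstatekey.to_be_bytes());
	v[8..].copy_from_slice(&shorteventid.to_be_bytes());
	v
}

/// Split the compressed form back into its two halves.
fn parse(compressed: CompressedStateEvent) -> (ShortStateKey, ShortEventId) {
	(
		u64::from_be_bytes(compressed[..8].try_into().expect("8 bytes")),
		u64::from_be_bytes(compressed[8..].try_into().expect("8 bytes")),
	)
}

fn main() {
	let packed = compress(42, 7);
	assert_eq!(parse(packed), (42, 7));
}
```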
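Because `RawPduId` is a small `Copy` value with an `AsRef<[u8]>` view, `SendingEvent::Pdu` can now hold it by value and `queue_requests` can append its bytes directly to the destination prefix instead of cloning a `Vec<u8>`. A rough sketch of that queue-key construction, with illustrative names only:

```rust
// Sketch of the sending-queue key layout implied by `queue_requests` above.
// `RawPduId` stands in for the fixed 16-byte id type; names are illustrative.
#[derive(Clone, Copy)]
struct RawPduId([u8; 16]);

impl AsRef<[u8]> for RawPduId {
	fn as_ref(&self) -> &[u8] { &self.0 }
}

#[allow(dead_code)]
enum SendingEvent {
	Pdu(RawPduId), // pduid
	Edu(Vec<u8>),  // pdu json
	Flush,         // none
}

/// Build a queue key: destination prefix, then the pdu id bytes
/// (or a monotonic count for non-PDU entries).
fn queue_key(mut prefix: Vec<u8>, event: &SendingEvent, next_count: u64) -> Vec<u8> {
	match event {
		SendingEvent::Pdu(id) => prefix.extend(id.as_ref()),
		_ => prefix.extend(&next_count.to_be_bytes()),
	}
	prefix
}

fn main() {
	let id = RawPduId([7_u8; 16]);
	let key = queue_key(b"@dest:".to_vec(), &SendingEvent::Pdu(id), 1);
	assert_eq!(key.len(), 6 + 16);
}
```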