refactor for stronger RawPduId type

implement standard traits for PduCount

enable serde for arrayvec

typedef various shortid's

pducount simplifications

split parts of pdu_metadata service to core/pdu and api/relations

remove some yields; improve var names/syntax

tweak types for limit timeline limit arguments

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-11-02 06:12:54 +00:00
parent 2e4d9cb37c
commit 9da523c004
41 changed files with 796 additions and 573 deletions

View file

@ -71,7 +71,7 @@ async fn fresh(services: &Services) -> Result<()> {
db["global"].insert(b"retroactively_fix_bad_data_from_roomuserid_joined", []);
// Create the admin room and server user on first run
crate::admin::create_admin_room(services).await?;
crate::admin::create_admin_room(services).boxed().await?;
warn!(
"Created new {} database with version {DATABASE_VERSION}",

View file

@ -7,9 +7,11 @@ use conduit::{err, utils, utils::math::usize_from_f64, Err, Result};
use database::Map;
use lru_cache::LruCache;
use crate::rooms::short::ShortEventId;
pub(super) struct Data {
shorteventid_authchain: Arc<Map>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[u64]>>>,
pub(super) auth_chain_cache: Mutex<LruCache<Vec<u64>, Arc<[ShortEventId]>>>,
}
impl Data {
@ -24,7 +26,7 @@ impl Data {
}
}
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Check RAM cache
@ -63,7 +65,7 @@ impl Data {
Ok(chain)
}
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[u64]>) {
pub(super) fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: Arc<[ShortEventId]>) {
debug_assert!(!key.is_empty(), "auth_chain key must not be empty");
// Only persist single events in db

View file

@ -10,7 +10,7 @@ use futures::Stream;
use ruma::{EventId, RoomId};
use self::data::Data;
use crate::{rooms, Dep};
use crate::{rooms, rooms::short::ShortEventId, Dep};
pub struct Service {
services: Services,
@ -64,7 +64,7 @@ impl Service {
}
#[tracing::instrument(skip_all, name = "auth_chain")]
pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<u64>> {
pub async fn get_auth_chain(&self, room_id: &RoomId, starting_events: &[&EventId]) -> Result<Vec<ShortEventId>> {
const NUM_BUCKETS: usize = 50; //TODO: change possible w/o disrupting db?
const BUCKET: BTreeSet<(u64, &EventId)> = BTreeSet::new();
@ -97,7 +97,7 @@ impl Service {
continue;
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
let chunk_key: Vec<ShortEventId> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Ok(cached) = self.get_cached_eventid_authchain(&chunk_key).await {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
@ -156,7 +156,7 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<ShortEventId>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
@ -195,19 +195,19 @@ impl Service {
}
#[inline]
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[u64]>> {
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Arc<[ShortEventId]>> {
self.db.get_cached_eventid_authchain(key).await
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<u64>) {
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
pub fn cache_auth_chain(&self, key: Vec<u64>, auth_chain: &HashSet<ShortEventId>) {
let val = auth_chain.iter().copied().collect::<Arc<[ShortEventId]>>();
self.db.cache_auth_chain(key, val);
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<u64>) {
let val = auth_chain.iter().copied().collect::<Arc<[u64]>>();
pub fn cache_auth_chain_vec(&self, key: Vec<u64>, auth_chain: &Vec<ShortEventId>) {
let val = auth_chain.iter().copied().collect::<Arc<[ShortEventId]>>();
self.db.cache_auth_chain(key, val);
}

View file

@ -35,7 +35,10 @@ use ruma::{
use crate::{
globals, rooms,
rooms::state_compressor::{CompressedStateEvent, HashSetCompressStateEvent},
rooms::{
state_compressor::{CompressedStateEvent, HashSetCompressStateEvent},
timeline::RawPduId,
},
sending, server_keys, Dep,
};
@ -136,10 +139,10 @@ impl Service {
pub async fn handle_incoming_pdu<'a>(
&self, origin: &'a ServerName, room_id: &'a RoomId, event_id: &'a EventId,
value: BTreeMap<String, CanonicalJsonValue>, is_timeline_event: bool,
) -> Result<Option<Vec<u8>>> {
) -> Result<Option<RawPduId>> {
// 1. Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.services.timeline.get_pdu_id(event_id).await {
return Ok(Some(pdu_id.to_vec()));
return Ok(Some(pdu_id));
}
// 1.1 Check the server is in the room
@ -488,7 +491,7 @@ impl Service {
pub async fn upgrade_outlier_to_timeline_pdu(
&self, incoming_pdu: Arc<PduEvent>, val: BTreeMap<String, CanonicalJsonValue>, create_event: &PduEvent,
origin: &ServerName, room_id: &RoomId,
) -> Result<Option<Vec<u8>>> {
) -> Result<Option<RawPduId>> {
// Skip the PDU if we already have it as a timeline event
if let Ok(pduid) = self
.services
@ -496,7 +499,7 @@ impl Service {
.get_pdu_id(&incoming_pdu.event_id)
.await
{
return Ok(Some(pduid.to_vec()));
return Ok(Some(pduid));
}
if self

View file

@ -2,15 +2,21 @@ use std::{mem::size_of, sync::Arc};
use conduit::{
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
utils::{stream::TryIgnore, u64_from_u8, ReadyExt},
PduCount, PduEvent,
};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::{api::Direction, EventId, RoomId, UserId};
use crate::{rooms, Dep};
use crate::{
rooms,
rooms::{
short::{ShortEventId, ShortRoomId},
timeline::{PduId, RawPduId},
},
Dep,
};
pub(super) struct Data {
tofrom_relation: Arc<Map>,
@ -46,35 +52,36 @@ impl Data {
}
pub(super) fn get_relations<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount, dir: Direction,
&'a self, user_id: &'a UserId, shortroomid: ShortRoomId, target: ShortEventId, from: PduCount, dir: Direction,
) -> impl Stream<Item = PdusIterItem> + Send + '_ {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
current.extend_from_slice(&0_u64.to_be_bytes());
u64::MAX.saturating_sub(x).saturating_sub(1)
},
};
current.extend_from_slice(&count_raw.to_be_bytes());
let current: RawPduId = PduId {
shortroomid,
shorteventid: from,
}
.into();
match dir {
Direction::Forward => self.tofrom_relation.raw_keys_from(&current).boxed(),
Direction::Backward => self.tofrom_relation.rev_raw_keys_from(&current).boxed(),
}
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|to_from| utils::u64_from_u8(&to_from[(size_of::<u64>())..]))
.filter_map(move |from| async move {
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?;
.ready_take_while(move |key| key.starts_with(&target.to_be_bytes()))
.map(|to_from| u64_from_u8(&to_from[8..16]))
.map(PduCount::from_unsigned)
.filter_map(move |shorteventid| async move {
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid,
}
.into();
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
Some((PduCount::Normal(from), pdu))
Some((shorteventid, pdu))
})
}

View file

@ -1,18 +1,9 @@
mod data;
use std::sync::Arc;
use conduit::{
at,
utils::{result::FlatOk, stream::ReadyExt, IterStream},
PduCount, Result,
};
use futures::{FutureExt, StreamExt};
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},
EventId, RoomId, UInt, UserId,
};
use serde::Deserialize;
use conduit::{PduCount, Result};
use futures::StreamExt;
use ruma::{api::Direction, EventId, RoomId, UserId};
use self::data::{Data, PdusIterItem};
use crate::{rooms, Dep};
@ -24,26 +15,14 @@ pub struct Service {
struct Services {
short: Dep<rooms::short::Service>,
state_accessor: Dep<rooms::state_accessor::Service>,
timeline: Dep<rooms::timeline::Service>,
}
#[derive(Clone, Debug, Deserialize)]
struct ExtractRelType {
rel_type: RelationType,
}
#[derive(Clone, Debug, Deserialize)]
struct ExtractRelatesToEventId {
#[serde(rename = "m.relates_to")]
relates_to: ExtractRelType,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
state_accessor: args.depend::<rooms::state_accessor::Service>("rooms::state_accessor"),
timeline: args.depend::<rooms::timeline::Service>("rooms::timeline"),
},
db: Data::new(&args),
@ -64,82 +43,9 @@ impl Service {
}
}
#[allow(clippy::too_many_arguments)]
pub async fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option<TimelineEventType>,
filter_rel_type: Option<RelationType>, from: Option<&str>, to: Option<&str>, limit: Option<UInt>,
recurse: bool, dir: Direction,
) -> Result<get_relating_events::v1::Response> {
let from = from
.map(PduCount::try_from_string)
.transpose()?
.unwrap_or_else(|| match dir {
Direction::Forward => PduCount::min(),
Direction::Backward => PduCount::max(),
});
let to = to.map(PduCount::try_from_string).flat_ok();
// Use limit or else 30, with maximum 100
let limit: usize = limit
.map(TryInto::try_into)
.flat_ok()
.unwrap_or(30)
.min(100);
// Spec (v1.10) recommends depth of at least 3
let depth: u8 = if recurse {
3
} else {
1
};
let events: Vec<PdusIterItem> = self
.get_relations(sender_user, room_id, target, from, limit, depth, dir)
.await
.into_iter()
.filter(|(_, pdu)| {
filter_event_type
.as_ref()
.is_none_or(|kind| *kind == pdu.kind)
})
.filter(|(_, pdu)| {
filter_rel_type.as_ref().is_none_or(|rel_type| {
pdu.get_content()
.map(|c: ExtractRelatesToEventId| c.relates_to.rel_type)
.is_ok_and(|r| r == *rel_type)
})
})
.stream()
.filter_map(|item| self.visibility_filter(sender_user, item))
.ready_take_while(|(count, _)| Some(*count) != to)
.take(limit)
.collect()
.boxed()
.await;
let next_batch = match dir {
Direction::Backward => events.first(),
Direction::Forward => events.last(),
}
.map(at!(0))
.map(|t| t.stringify());
Ok(get_relating_events::v1::Response {
next_batch,
prev_batch: Some(from.stringify()),
recursion_depth: recurse.then_some(depth.into()),
chunk: events
.into_iter()
.map(at!(1))
.map(|pdu| pdu.to_message_like_event())
.collect(),
})
}
#[allow(clippy::too_many_arguments)]
pub async fn get_relations(
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, limit: usize, max_depth: u8,
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, from: PduCount, limit: usize, max_depth: u8,
dir: Direction,
) -> Vec<PdusIterItem> {
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
@ -152,7 +58,7 @@ impl Service {
let mut pdus: Vec<_> = self
.db
.get_relations(user_id, room_id, target, until, dir)
.get_relations(user_id, room_id, target, from, dir)
.collect()
.await;
@ -167,7 +73,7 @@ impl Service {
let relations: Vec<_> = self
.db
.get_relations(user_id, room_id, target, until, dir)
.get_relations(user_id, room_id, target, from, dir)
.collect()
.await;
@ -186,16 +92,6 @@ impl Service {
pdus
}
async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option<PdusIterItem> {
let (_, pdu) = &item;
self.services
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.await
.then_some(item)
}
#[inline]
#[tracing::instrument(skip_all, level = "debug")]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) {

View file

@ -1,10 +1,10 @@
use std::{iter, sync::Arc};
use std::sync::Arc;
use arrayvec::ArrayVec;
use conduit::{
implement,
utils::{set, stream::TryIgnore, ArrayVecExt, IterStream, ReadyExt},
PduEvent, Result,
PduCount, PduEvent, Result,
};
use database::{keyval::Val, Map};
use futures::{Stream, StreamExt};
@ -66,13 +66,13 @@ impl crate::Service for Service {
}
#[implement(Service)]
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
pub fn index_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) {
let batch = tokenize(message_body)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here
(key, Vec::<u8>::new())
})
.collect::<Vec<_>>();
@ -81,12 +81,12 @@ pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
}
#[implement(Service)]
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
pub fn deindex_pdu(&self, shortroomid: ShortRoomId, pdu_id: &RawPduId, message_body: &str) {
let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
key.push(0xFF);
key.extend_from_slice(pdu_id); // TODO: currently we save the room id a second time here
key.extend_from_slice(pdu_id.as_ref()); // TODO: currently we save the room id a second time here
key
});
@ -159,24 +159,24 @@ fn search_pdu_ids_query_words<'a>(
&'a self, shortroomid: ShortRoomId, word: &'a str,
) -> impl Stream<Item = RawPduId> + Send + '_ {
self.search_pdu_ids_query_word(shortroomid, word)
.ready_filter_map(move |key| {
key[prefix_len(word)..]
.chunks_exact(PduId::LEN)
.next()
.map(RawPduId::try_from)
.and_then(Result::ok)
.map(move |key| -> RawPduId {
let key = &key[prefix_len(word)..];
key.into()
})
}
/// Iterate over raw database results for a word
#[implement(Service)]
fn search_pdu_ids_query_word(&self, shortroomid: ShortRoomId, word: &str) -> impl Stream<Item = Val<'_>> + Send + '_ {
const PDUID_LEN: usize = PduId::LEN;
// rustc says const'ing this not yet stable
let end_id: ArrayVec<u8, PDUID_LEN> = iter::repeat(u8::MAX).take(PduId::LEN).collect();
let end_id: RawPduId = PduId {
shortroomid,
shorteventid: PduCount::max(),
}
.into();
// Newest pdus first
let end = make_tokenid(shortroomid, word, end_id.as_slice());
let end = make_tokenid(shortroomid, word, &end_id);
let prefix = make_prefix(shortroomid, word);
self.db
.tokenids
@ -196,11 +196,9 @@ fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
.map(str::to_lowercase)
}
fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &[u8]) -> TokenId {
debug_assert!(pdu_id.len() == PduId::LEN, "pdu_id size mismatch");
fn make_tokenid(shortroomid: ShortRoomId, word: &str, pdu_id: &RawPduId) -> TokenId {
let mut key = make_prefix(shortroomid, word);
key.extend_from_slice(pdu_id);
key.extend_from_slice(pdu_id.as_ref());
key
}

View file

@ -1,5 +1,6 @@
use std::{mem::size_of_val, sync::Arc};
pub use conduit::pdu::{ShortEventId, ShortId, ShortRoomId};
use conduit::{err, implement, utils, Result};
use database::{Deserialized, Map};
use ruma::{events::StateEventType, EventId, RoomId};
@ -26,9 +27,6 @@ struct Services {
pub type ShortStateHash = ShortId;
pub type ShortStateKey = ShortId;
pub type ShortEventId = ShortId;
pub type ShortRoomId = ShortId;
pub type ShortId = u64;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
@ -52,7 +50,7 @@ impl crate::Service for Service {
#[implement(Service)]
pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> ShortEventId {
const BUFSIZE: usize = size_of::<u64>();
const BUFSIZE: usize = size_of::<ShortEventId>();
if let Ok(shorteventid) = self
.db
@ -88,7 +86,7 @@ pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) ->
.map(|(i, result)| match result {
Ok(ref short) => utils::u64_from_u8(short),
Err(_) => {
const BUFSIZE: usize = size_of::<u64>();
const BUFSIZE: usize = size_of::<ShortEventId>();
let short = self.services.globals.next_count().unwrap();
debug_assert!(size_of_val(&short) == BUFSIZE, "buffer requirement changed");

View file

@ -33,7 +33,7 @@ use ruma::{
};
use tokio::sync::Mutex;
use crate::{rooms, sending, Dep};
use crate::{rooms, rooms::short::ShortRoomId, sending, Dep};
pub struct CachedSpaceHierarchySummary {
summary: SpaceHierarchyParentSummary,
@ -49,7 +49,7 @@ pub enum SummaryAccessibility {
pub struct PaginationToken {
/// Path down the hierarchy of the room to start the response at,
/// excluding the root space.
pub short_room_ids: Vec<u64>,
pub short_room_ids: Vec<ShortRoomId>,
pub limit: UInt,
pub max_depth: UInt,
pub suggested_only: bool,
@ -448,7 +448,7 @@ impl Service {
}
pub async fn get_client_hierarchy(
&self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<u64>, max_depth: u64,
&self, sender_user: &UserId, room_id: &RoomId, limit: usize, short_room_ids: Vec<ShortRoomId>, max_depth: u64,
suggested_only: bool,
) -> Result<client::space::get_hierarchy::v1::Response> {
let mut parents = VecDeque::new();

View file

@ -95,7 +95,7 @@ impl Service {
let event_ids = statediffnew.iter().stream().filter_map(|new| {
self.services
.state_compressor
.parse_compressed_state_event(new)
.parse_compressed_state_event(*new)
.map_ok_or_else(|_| None, |(_, event_id)| Some(event_id))
});
@ -428,7 +428,7 @@ impl Service {
let Ok((shortstatekey, event_id)) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)
.parse_compressed_state_event(*compressed)
.await
else {
continue;

View file

@ -53,7 +53,7 @@ impl Data {
let parsed = self
.services
.state_compressor
.parse_compressed_state_event(compressed)
.parse_compressed_state_event(*compressed)
.await?;
result.insert(parsed.0, parsed.1);
@ -86,7 +86,7 @@ impl Data {
let (_, eventid) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)
.parse_compressed_state_event(*compressed)
.await?;
if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await {
@ -132,7 +132,7 @@ impl Data {
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.parse_compressed_state_event(*compressed)
.map_ok(|(_, id)| id)
.map_err(|e| {
err!(Database(error!(

View file

@ -39,13 +39,17 @@ use ruma::{
use serde::Deserialize;
use self::data::Data;
use crate::{rooms, rooms::state::RoomMutexGuard, Dep};
use crate::{
rooms,
rooms::{short::ShortStateHash, state::RoomMutexGuard},
Dep,
};
pub struct Service {
services: Services,
db: Data,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, u64), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, u64), bool>>,
pub server_visibility_cache: Mutex<LruCache<(OwnedServerName, ShortStateHash), bool>>,
pub user_visibility_cache: Mutex<LruCache<(OwnedUserId, ShortStateHash), bool>>,
}
struct Services {
@ -94,11 +98,13 @@ impl Service {
/// Builds a StateMap by iterating over all keys that start
/// with state_hash, this gives the full state for the given state_hash.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_full_ids(&self, shortstatehash: u64) -> Result<HashMap<u64, Arc<EventId>>> {
pub async fn state_full_ids(&self, shortstatehash: ShortStateHash) -> Result<HashMap<u64, Arc<EventId>>> {
self.db.state_full_ids(shortstatehash).await
}
pub async fn state_full(&self, shortstatehash: u64) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
pub async fn state_full(
&self, shortstatehash: ShortStateHash,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
self.db.state_full(shortstatehash).await
}
@ -106,7 +112,7 @@ impl Service {
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub async fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<EventId>> {
self.db
.state_get_id(shortstatehash, event_type, state_key)
@ -117,7 +123,7 @@ impl Service {
/// `state_key`).
#[inline]
pub async fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<Arc<PduEvent>> {
self.db
.state_get(shortstatehash, event_type, state_key)
@ -126,7 +132,7 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn state_get_content<T>(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
&self, shortstatehash: ShortStateHash, event_type: &StateEventType, state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de> + Send,
@ -137,7 +143,7 @@ impl Service {
}
/// Get membership for given user in state
async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState {
async fn user_membership(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> MembershipState {
self.state_get_content(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
.await
.map_or(MembershipState::Leave, |c: RoomMemberEventContent| c.membership)
@ -145,14 +151,14 @@ impl Service {
/// The user was a joined member at this state (potentially in the past)
#[inline]
async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
async fn user_was_joined(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).await == MembershipState::Join
}
/// The user was an invited or joined room member at this state (potentially
/// in the past)
#[inline]
async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
async fn user_was_invited(&self, shortstatehash: ShortStateHash, user_id: &UserId) -> bool {
let s = self.user_membership(shortstatehash, user_id).await;
s == MembershipState::Join || s == MembershipState::Invite
}
@ -285,7 +291,7 @@ impl Service {
}
/// Returns the state hash for this pdu.
pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<u64> {
pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<ShortStateHash> {
self.db.pdu_shortstatehash(event_id).await
}

View file

@ -34,25 +34,26 @@ struct Data {
#[derive(Clone)]
struct StateDiff {
parent: Option<u64>,
added: Arc<HashSet<CompressedStateEvent>>,
removed: Arc<HashSet<CompressedStateEvent>>,
added: Arc<CompressedState>,
removed: Arc<CompressedState>,
}
#[derive(Clone, Default)]
pub struct ShortStateInfo {
pub shortstatehash: ShortStateHash,
pub full_state: Arc<HashSet<CompressedStateEvent>>,
pub added: Arc<HashSet<CompressedStateEvent>>,
pub removed: Arc<HashSet<CompressedStateEvent>>,
pub full_state: Arc<CompressedState>,
pub added: Arc<CompressedState>,
pub removed: Arc<CompressedState>,
}
#[derive(Clone, Default)]
pub struct HashSetCompressStateEvent {
pub shortstatehash: ShortStateHash,
pub added: Arc<HashSet<CompressedStateEvent>>,
pub removed: Arc<HashSet<CompressedStateEvent>>,
pub added: Arc<CompressedState>,
pub removed: Arc<CompressedState>,
}
pub(crate) type CompressedState = HashSet<CompressedStateEvent>;
pub(crate) type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
type StateInfoLruCache = LruCache<ShortStateHash, ShortStateInfoVec>;
type ShortStateInfoVec = Vec<ShortStateInfo>;
@ -105,7 +106,7 @@ impl Service {
removed,
} = self.get_statediff(shortstatehash).await?;
if let Some(parent) = parent {
let response = if let Some(parent) = parent {
let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?;
let mut state = (*response.last().expect("at least one response").full_state).clone();
state.extend(added.iter().copied());
@ -121,27 +122,22 @@ impl Service {
removed: Arc::new(removed),
});
self.stateinfo_cache
.lock()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
response
} else {
let response = vec![ShortStateInfo {
vec![ShortStateInfo {
shortstatehash,
full_state: added.clone(),
added,
removed,
}];
}]
};
self.stateinfo_cache
.lock()
.expect("locked")
.insert(shortstatehash, response.clone());
self.stateinfo_cache
.lock()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
}
Ok(response)
}
pub async fn compress_state_event(&self, shortstatekey: ShortStateKey, event_id: &EventId) -> CompressedStateEvent {
@ -161,7 +157,7 @@ impl Service {
/// Returns shortstatekey, event id
#[inline]
pub async fn parse_compressed_state_event(
&self, compressed_event: &CompressedStateEvent,
&self, compressed_event: CompressedStateEvent,
) -> Result<(ShortStateKey, Arc<EventId>)> {
use utils::u64_from_u8;

View file

@ -1,17 +1,22 @@
use std::{mem::size_of, sync::Arc};
use std::sync::Arc;
use conduit::{
checked,
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
PduEvent, Result,
PduCount, PduEvent, Result,
};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{rooms, Dep};
use crate::{
rooms,
rooms::{
short::ShortRoomId,
timeline::{PduId, RawPduId},
},
Dep,
};
pub(super) struct Data {
threadid_userids: Arc<Map>,
@ -35,40 +40,39 @@ impl Data {
}
}
#[inline]
pub(super) async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
let prefix = self
.services
.short
.get_shortroomid(room_id)
.await?
.to_be_bytes()
.to_vec();
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, _include: &'a IncludeThreads,
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
let shortroomid: ShortRoomId = self.services.short.get_shortroomid(room_id).await?;
let mut current = prefix.clone();
current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes());
let current: RawPduId = PduId {
shortroomid,
shorteventid: until.saturating_sub(1),
}
.into();
let stream = self
.threadid_userids
.rev_raw_keys_from(&current)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|pduid| (utils::u64_from_u8(&pduid[(size_of::<u64>())..]), pduid))
.filter_map(move |(count, pduid)| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?;
.map(RawPduId::from)
.ready_take_while(move |pdu_id| pdu_id.shortroomid() == shortroomid.to_be_bytes())
.filter_map(move |pdu_id| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(&pdu_id).await.ok()?;
let pdu_id: PduId = pdu_id.into();
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
Some((count, pdu))
Some((pdu_id.shorteventid, pdu))
});
Ok(stream)
}
pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
pub(super) fn update_participants(&self, root_id: &RawPduId, participants: &[OwnedUserId]) -> Result {
let users = participants
.iter()
.map(|user| user.as_bytes())
@ -80,7 +84,7 @@ impl Data {
Ok(())
}
pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result<Vec<OwnedUserId>> {
self.threadid_userids.qry(root_id).await.deserialized()
pub(super) async fn get_participants(&self, root_id: &RawPduId) -> Result<Vec<OwnedUserId>> {
self.threadid_userids.get(root_id).await.deserialized()
}
}

View file

@ -2,7 +2,7 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{err, PduEvent, Result};
use conduit::{err, PduCount, PduEvent, Result};
use data::Data;
use futures::Stream;
use ruma::{
@ -37,8 +37,8 @@ impl crate::Service for Service {
impl Service {
pub async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads,
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount, include: &'a IncludeThreads,
) -> Result<impl Stream<Item = (PduCount, PduEvent)> + Send + 'a> {
self.db
.threads_until(user_id, room_id, until, include)
.await

View file

@ -1,14 +1,13 @@
use std::{
collections::{hash_map, HashMap},
mem::size_of,
sync::Arc,
};
use conduit::{
err, expected,
at, err,
result::{LogErr, NotFound},
utils,
utils::{future::TryExtExt, stream::TryIgnore, u64_from_u8, ReadyExt},
utils::{future::TryExtExt, stream::TryIgnore, ReadyExt},
Err, PduCount, PduEvent, Result,
};
use database::{Database, Deserialized, Json, KeyVal, Map};
@ -16,7 +15,8 @@ use futures::{Stream, StreamExt};
use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use crate::{rooms, Dep};
use super::{PduId, RawPduId};
use crate::{rooms, rooms::short::ShortRoomId, Dep};
pub(super) struct Data {
eventid_outlierpdu: Arc<Map>,
@ -58,30 +58,25 @@ impl Data {
.lasttimelinecount_cache
.lock()
.await
.entry(room_id.to_owned())
.entry(room_id.into())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())
.await?
.next()
.await
{
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
}
},
hash_map::Entry::Occupied(o) => Ok(*o.get()),
hash_map::Entry::Vacant(v) => Ok(self
.pdus_until(sender_user, room_id, PduCount::max())
.await?
.next()
.await
.map(at!(0))
.filter(|&count| matches!(count, PduCount::Normal(_)))
.map_or_else(PduCount::max, |count| *v.insert(count))),
}
}
/// Returns the `count` of this pdu's id.
pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
self.eventid_pduid
.get(event_id)
self.get_pdu_id(event_id)
.await
.map(|pdu_id| pdu_count(&pdu_id))
.map(|pdu_id| pdu_id.pdu_count())
}
/// Returns the json of a pdu.
@ -102,8 +97,11 @@ impl Data {
/// Returns the pdu's id.
#[inline]
pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result<database::Handle<'_>> {
self.eventid_pduid.get(event_id).await
pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> {
self.eventid_pduid
.get(event_id)
.await
.map(|handle| RawPduId::from(&*handle))
}
/// Returns the pdu directly from `eventid_pduid` only.
@ -154,34 +152,40 @@ impl Data {
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<PduEvent> {
pub(super) async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> {
self.pduid_pdu.get(pdu_id).await.deserialized()
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<CanonicalJsonObject> {
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
self.pduid_pdu.get(pdu_id).await.deserialized()
}
pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) {
pub(super) async fn append_pdu(
&self, pdu_id: &RawPduId, pdu: &PduEvent, json: &CanonicalJsonObject, count: PduCount,
) {
debug_assert!(matches!(count, PduCount::Normal(_)), "PduCount not Normal");
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.lasttimelinecount_cache
.lock()
.await
.insert(pdu.room_id.clone(), PduCount::Normal(count));
.insert(pdu.room_id.clone(), count);
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) {
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &RawPduId, event_id: &EventId, json: &CanonicalJsonObject) {
self.pduid_pdu.raw_put(pdu_id, Json(json));
self.eventid_pduid.insert(event_id, pdu_id);
self.eventid_outlierpdu.remove(event_id);
}
/// Removes a pdu and creates a new one with the same id.
pub(super) async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result {
pub(super) async fn replace_pdu(
&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
) -> Result {
if self.pduid_pdu.get(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));
}
@ -197,13 +201,14 @@ impl Data {
pub(super) async fn pdus_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?;
let current = self.count_to_id(room_id, until, true).await?;
let prefix = current.shortroomid();
let stream = self
.pduid_pdu
.rev_raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
.map(|item| Self::each_pdu(item, user_id));
Ok(stream)
}
@ -211,7 +216,8 @@ impl Data {
pub(super) async fn pdus_after<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?;
let current = self.count_to_id(room_id, from, false).await?;
let prefix = current.shortroomid();
let stream = self
.pduid_pdu
.raw_stream_from(&current)
@ -223,6 +229,8 @@ impl Data {
}
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem {
let pdu_id: RawPduId = pdu_id.into();
let mut pdu =
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
@ -231,9 +239,8 @@ impl Data {
}
pdu.add_age().log_err().ok();
let count = pdu_count(pdu_id);
(count, pdu)
(pdu_id.pdu_count(), pdu)
}
pub(super) fn increment_notification_counts(
@ -256,56 +263,25 @@ impl Data {
}
}
pub(super) async fn count_to_id(
&self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = self
async fn count_to_id(&self, room_id: &RoomId, count: PduCount, subtract: bool) -> Result<RawPduId> {
let shortroomid: ShortRoomId = self
.services
.short
.get_shortroomid(room_id)
.await
.map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?
.to_be_bytes()
.to_vec();
.map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?;
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
PduCount::Normal(x) => {
if subtract {
x.saturating_sub(offset)
} else {
x.saturating_add(offset)
}
},
PduCount::Backfilled(x) => {
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
let num = u64::MAX.saturating_sub(x);
if subtract {
num.saturating_sub(offset)
} else {
num.saturating_add(offset)
}
let pdu_id = PduId {
shortroomid,
shorteventid: if subtract {
count.checked_sub(1)?
} else {
count.checked_add(1)?
},
};
pdu_id.extend_from_slice(&count_raw.to_be_bytes());
Ok((prefix, pdu_id))
}
}
/// Returns the `count` of this pdu's id.
pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount {
const STRIDE: usize = size_of::<u64>();
let pdu_id_len = pdu_id.len();
let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]);
let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]);
if second_last_u64 == 0 {
PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))
} else {
PduCount::Normal(last_u64)
Ok(pdu_id.into())
}
}

View file

@ -1,5 +1,4 @@
mod data;
mod pduid;
use std::{
cmp,
@ -15,6 +14,7 @@ use conduit::{
utils::{stream::TryIgnore, IterStream, MutexMap, MutexMapGuard, ReadyExt},
validated, warn, Err, Error, Result, Server,
};
pub use conduit::{PduId, RawPduId};
use futures::{future, future::ready, Future, FutureExt, Stream, StreamExt, TryStreamExt};
use ruma::{
api::federation,
@ -39,13 +39,13 @@ use serde::Deserialize;
use serde_json::value::{to_raw_value, RawValue as RawJsonValue};
use self::data::Data;
pub use self::{
data::PdusIterItem,
pduid::{PduId, RawPduId},
};
pub use self::data::PdusIterItem;
use crate::{
account_data, admin, appservice, appservice::NamespaceRegex, globals, pusher, rooms,
rooms::state_compressor::CompressedStateEvent, sending, server_keys, users, Dep,
account_data, admin, appservice,
appservice::NamespaceRegex,
globals, pusher, rooms,
rooms::{short::ShortRoomId, state_compressor::CompressedStateEvent},
sending, server_keys, users, Dep,
};
// Update Relationships
@ -229,9 +229,7 @@ impl Service {
/// Returns the pdu's id.
#[inline]
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<database::Handle<'_>> {
self.db.get_pdu_id(event_id).await
}
pub async fn get_pdu_id(&self, event_id: &EventId) -> Result<RawPduId> { self.db.get_pdu_id(event_id).await }
/// Returns the pdu.
///
@ -256,16 +254,16 @@ impl Service {
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await }
pub async fn get_pdu_from_id(&self, pdu_id: &RawPduId) -> Result<PduEvent> { self.db.get_pdu_from_id(pdu_id).await }
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<CanonicalJsonObject> {
pub async fn get_pdu_json_from_id(&self, pdu_id: &RawPduId) -> Result<CanonicalJsonObject> {
self.db.get_pdu_json_from_id(pdu_id).await
}
/// Removes a pdu and creates a new one with the same id.
#[tracing::instrument(skip(self), level = "debug")]
pub async fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
pub async fn replace_pdu(&self, pdu_id: &RawPduId, pdu_json: &CanonicalJsonObject, pdu: &PduEvent) -> Result<()> {
self.db.replace_pdu(pdu_id, pdu_json, pdu).await
}
@ -282,7 +280,7 @@ impl Service {
mut pdu_json: CanonicalJsonObject,
leaves: Vec<OwnedEventId>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Vec<u8>> {
) -> Result<RawPduId> {
// Coalesce database writes for the remainder of this scope.
let _cork = self.db.db.cork_and_flush();
@ -359,9 +357,12 @@ impl Service {
.user
.reset_notification_counts(&pdu.sender, &pdu.room_id);
let count2 = self.services.globals.next_count().unwrap();
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&count2.to_be_bytes());
let count2 = PduCount::Normal(self.services.globals.next_count().unwrap());
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid: count2,
}
.into();
// Insert pdu
self.db.append_pdu(&pdu_id, pdu, &pdu_json, count2).await;
@ -544,7 +545,7 @@ impl Service {
if let Ok(related_pducount) = self.get_pdu_count(&content.relates_to.event_id).await {
self.services
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount);
.add_relation(count2, related_pducount);
}
}
@ -558,7 +559,7 @@ impl Service {
if let Ok(related_pducount) = self.get_pdu_count(&in_reply_to.event_id).await {
self.services
.pdu_metadata
.add_relation(PduCount::Normal(count2), related_pducount);
.add_relation(count2, related_pducount);
}
},
Relation::Thread(thread) => {
@ -580,7 +581,7 @@ impl Service {
{
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
continue;
}
@ -596,7 +597,7 @@ impl Service {
if state_key_uid == appservice_uid {
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
continue;
}
}
@ -623,7 +624,7 @@ impl Service {
{
self.services
.sending
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id.clone())?;
.send_pdu_appservice(appservice.registration.id.clone(), pdu_id)?;
}
}
@ -935,7 +936,7 @@ impl Service {
state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
soft_fail: bool,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<Option<Vec<u8>>> {
) -> Result<Option<RawPduId>> {
// We append to state before appending the pdu, so we don't have a moment in
// time with the pdu without it's state. This is okay because append_pdu can't
// fail.
@ -993,7 +994,7 @@ impl Service {
/// Replace a PDU with the redacted form.
#[tracing::instrument(skip(self, reason))]
pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: u64) -> Result<()> {
pub async fn redact_pdu(&self, event_id: &EventId, reason: &PduEvent, shortroomid: ShortRoomId) -> Result {
// TODO: Don't reserialize, keep original json
let Ok(pdu_id) = self.get_pdu_id(event_id).await else {
// If event does not exist, just noop
@ -1133,7 +1134,6 @@ impl Service {
// Skip the PDU if we already have it as a timeline event
if let Ok(pdu_id) = self.get_pdu_id(&event_id).await {
let pdu_id = pdu_id.to_vec();
debug!("We already know {event_id} at {pdu_id:?}");
return Ok(());
}
@ -1158,11 +1158,13 @@ impl Service {
let insert_lock = self.mutex_insert.lock(&room_id).await;
let max = u64::MAX;
let count = self.services.globals.next_count().unwrap();
let mut pdu_id = shortroomid.to_be_bytes().to_vec();
pdu_id.extend_from_slice(&0_u64.to_be_bytes());
pdu_id.extend_from_slice(&(validated!(max - count)).to_be_bytes());
let count: i64 = self.services.globals.next_count().unwrap().try_into()?;
let pdu_id: RawPduId = PduId {
shortroomid,
shorteventid: PduCount::Backfilled(validated!(0 - count)),
}
.into();
// Insert pdu
self.db.prepend_backfill_pdu(&pdu_id, &event_id, &value);
@ -1246,16 +1248,3 @@ async fn check_pdu_for_admin_room(&self, pdu: &PduEvent, sender: &UserId) -> Res
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn comparisons() {
assert!(PduCount::Normal(1) < PduCount::Normal(2));
assert!(PduCount::Backfilled(2) < PduCount::Backfilled(1));
assert!(PduCount::Normal(1) > PduCount::Backfilled(1));
assert!(PduCount::Backfilled(1) < PduCount::Normal(1));
}
}

View file

@ -1,13 +0,0 @@
use crate::rooms::short::{ShortEventId, ShortRoomId};
#[derive(Clone, Copy)]
pub struct PduId {
_room_id: ShortRoomId,
_event_id: ShortEventId,
}
pub type RawPduId = [u8; PduId::LEN];
impl PduId {
pub const LEN: usize = size_of::<ShortRoomId>() + size_of::<ShortEventId>();
}

View file

@ -5,7 +5,7 @@ use database::{Deserialized, Map};
use futures::{pin_mut, Stream, StreamExt};
use ruma::{RoomId, UserId};
use crate::{globals, rooms, Dep};
use crate::{globals, rooms, rooms::short::ShortStateHash, Dep};
pub struct Service {
db: Data,
@ -93,7 +93,7 @@ pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -
}
#[implement(Service)]
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) {
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: ShortStateHash) {
let shortroomid = self
.services
.short
@ -108,7 +108,7 @@ pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64,
}
#[implement(Service)]
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<u64> {
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<ShortStateHash> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];

View file

@ -115,10 +115,10 @@ impl Data {
let mut keys = Vec::new();
for (event, destination) in requests {
let mut key = destination.get_prefix();
if let SendingEvent::Pdu(value) = &event {
key.extend_from_slice(value);
if let SendingEvent::Pdu(value) = event {
key.extend(value.as_ref());
} else {
key.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
key.extend(&self.services.globals.next_count().unwrap().to_be_bytes());
}
let value = if let SendingEvent::Edu(value) = &event {
&**value
@ -175,7 +175,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
(
Destination::Appservice(server),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
SendingEvent::Pdu(event.into())
} else {
SendingEvent::Edu(value.to_vec())
},
@ -202,7 +202,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
(
Destination::Push(user_id, pushkey_string),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
SendingEvent::Pdu(event.into())
} else {
// I'm pretty sure this should never be called
SendingEvent::Edu(value.to_vec())
@ -225,7 +225,7 @@ fn parse_servercurrentevent(key: &[u8], value: &[u8]) -> Result<(Destination, Se
.map_err(|_| Error::bad_database("Invalid server string in server_currenttransaction"))?,
),
if value.is_empty() {
SendingEvent::Pdu(event.to_vec())
SendingEvent::Pdu(event.into())
} else {
SendingEvent::Edu(value.to_vec())
},

View file

@ -24,7 +24,10 @@ pub use self::{
dest::Destination,
sender::{EDU_LIMIT, PDU_LIMIT},
};
use crate::{account_data, client, globals, presence, pusher, resolver, rooms, server_keys, users, Dep};
use crate::{
account_data, client, globals, presence, pusher, resolver, rooms, rooms::timeline::RawPduId, server_keys, users,
Dep,
};
pub struct Service {
server: Arc<Server>,
@ -61,9 +64,9 @@ struct Msg {
#[allow(clippy::module_name_repetitions)]
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub enum SendingEvent {
Pdu(Vec<u8>), // pduid
Edu(Vec<u8>), // pdu json
Flush, // none
Pdu(RawPduId), // pduid
Edu(Vec<u8>), // pdu json
Flush, // none
}
#[async_trait]
@ -110,9 +113,9 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self, pdu_id, user, pushkey), level = "debug")]
pub fn send_pdu_push(&self, pdu_id: &[u8], user: &UserId, pushkey: String) -> Result<()> {
pub fn send_pdu_push(&self, pdu_id: &RawPduId, user: &UserId, pushkey: String) -> Result {
let dest = Destination::Push(user.to_owned(), pushkey);
let event = SendingEvent::Pdu(pdu_id.to_owned());
let event = SendingEvent::Pdu(*pdu_id);
let _cork = self.db.db.cork();
let keys = self.db.queue_requests(&[(&event, &dest)]);
self.dispatch(Msg {
@ -123,7 +126,7 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: Vec<u8>) -> Result<()> {
pub fn send_pdu_appservice(&self, appservice_id: String, pdu_id: RawPduId) -> Result {
let dest = Destination::Appservice(appservice_id);
let event = SendingEvent::Pdu(pdu_id);
let _cork = self.db.db.cork();
@ -136,7 +139,7 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id, pdu_id), level = "debug")]
pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &[u8]) -> Result<()> {
pub async fn send_pdu_room(&self, room_id: &RoomId, pdu_id: &RawPduId) -> Result {
let servers = self
.services
.state_cache
@ -147,13 +150,13 @@ impl Service {
}
#[tracing::instrument(skip(self, servers, pdu_id), level = "debug")]
pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &[u8]) -> Result<()>
pub async fn send_pdu_servers<'a, S>(&self, servers: S, pdu_id: &RawPduId) -> Result
where
S: Stream<Item = &'a ServerName> + Send + 'a,
{
let _cork = self.db.db.cork();
let requests = servers
.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.into())))
.map(|server| (Destination::Normal(server.into()), SendingEvent::Pdu(pdu_id.to_owned())))
.collect::<Vec<_>>()
.await;

View file

@ -536,7 +536,8 @@ impl Service {
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Edu(b) => &**b,
SendingEvent::Pdu(b) => b.as_ref(),
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),
@ -660,7 +661,8 @@ impl Service {
&events
.iter()
.map(|e| match e {
SendingEvent::Edu(b) | SendingEvent::Pdu(b) => &**b,
SendingEvent::Edu(b) => &**b,
SendingEvent::Pdu(b) => b.as_ref(),
SendingEvent::Flush => &[],
})
.collect::<Vec<_>>(),