Database Refactor

combine service/users data w/ mod unit

split sliding sync related out of service/users

instrument database entry points

remove increment crap from database interface

de-wrap all database get() calls

de-wrap all database insert() calls

de-wrap all database remove() calls

refactor database interface for async streaming

add query key serializer for database

implement Debug for result handle

add query deserializer for database

add deserialization trait for option handle

start a stream utils suite

de-wrap/asyncify/type-query count_one_time_keys()

de-wrap/asyncify users count

add admin query users command suite

de-wrap/asyncify users exists

de-wrap/partially asyncify user filter related

asyncify/de-wrap users device/keys related

asyncify/de-wrap user auth/misc related

asyncify/de-wrap users blurhash

asyncify/de-wrap account_data get; merge Data into Service

partial asyncify/de-wrap uiaa; merge Data into Service

partially asyncify/de-wrap transaction_ids get; merge Data into Service

partially asyncify/de-wrap key_backups; merge Data into Service

asyncify/de-wrap pusher service getters; merge Data into Service

asyncify/de-wrap rooms alias getters/some iterators

asyncify/de-wrap rooms directory getters/iterator

partially asyncify/de-wrap rooms lazy-loading

partially asyncify/de-wrap rooms metadata

asyncify/dewrap rooms outlier

asyncify/dewrap rooms pdu_metadata

dewrap/partially asyncify rooms read receipt

de-wrap rooms search service

de-wrap/partially asyncify rooms user service

partial de-wrap rooms state_compressor

de-wrap rooms state_cache

de-wrap room state et al

de-wrap rooms timeline service

additional users device/keys related

de-wrap/asyncify sender

asyncify services

refactor database to TryFuture/TryStream

refactor services for TryFuture/TryStream

asyncify api handlers

additional asyncification for admin module

abstract stream related; support reverse streams

additional stream conversions

asyncify state-res related

Signed-off-by: Jason Volk <jason@zemos.net>
This commit is contained in:
Jason Volk 2024-08-08 17:18:30 +00:00 committed by strawberry
parent 6001014078
commit 946ca364e0
203 changed files with 12202 additions and 10709 deletions

View file

@ -1,125 +0,0 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{api::client::error::ErrorKind, OwnedRoomAliasId, OwnedRoomId, OwnedUserId, RoomAliasId, RoomId, UserId};
use crate::{globals, Dep};
pub(super) struct Data {
alias_userid: Arc<Map>,
alias_roomid: Arc<Map>,
aliasid_alias: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
}
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
alias_userid: db["alias_userid"].clone(),
alias_roomid: db["alias_roomid"].clone(),
aliasid_alias: db["aliasid_alias"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
},
}
}
pub(super) fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
// Comes first as we don't want a stuck alias
self.alias_userid
.insert(alias.alias().as_bytes(), user_id.as_bytes())?;
self.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes())?;
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.aliasid_alias.insert(&aliasid, alias.as_bytes())?;
Ok(())
}
pub(super) fn remove_alias(&self, alias: &RoomAliasId) -> Result<()> {
if let Some(room_id) = self.alias_roomid.get(alias.alias().as_bytes())? {
let mut prefix = room_id.to_vec();
prefix.push(0xFF);
for (key, _) in self.aliasid_alias.scan_prefix(prefix) {
self.aliasid_alias.remove(&key)?;
}
self.alias_roomid.remove(alias.alias().as_bytes())?;
self.alias_userid.remove(alias.alias().as_bytes())?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias does not exist or is invalid."));
}
Ok(())
}
pub(super) fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.alias_roomid
.get(alias.alias().as_bytes())?
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in alias_roomid is invalid."))
})
.transpose()
}
pub(super) fn who_created_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedUserId>> {
self.alias_userid
.get(alias.alias().as_bytes())?
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("User ID in alias_userid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in alias_roomid is invalid."))
})
.transpose()
}
pub(super) fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.aliasid_alias.scan_prefix(prefix).map(|(_, bytes)| {
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid alias in aliasid_alias."))
}))
}
pub(super) fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
Box::new(
self.alias_roomid
.iter()
.map(|(room_alias_bytes, room_id_bytes)| {
let room_alias_localpart = utils::string_from_bytes(&room_alias_bytes)
.map_err(|_| Error::bad_database("Invalid alias bytes in aliasid_alias."))?;
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|_| Error::bad_database("Invalid room_id bytes in aliasid_alias."))?
.try_into()
.map_err(|_| Error::bad_database("Invalid room_id in aliasid_alias."))?;
Ok((room_id, room_alias_localpart))
}),
)
}
}

View file

@ -1,19 +1,23 @@
mod data;
mod remote;
use std::sync::Arc;
use conduit::{err, Error, Result};
use conduit::{
err,
utils::{stream::TryIgnore, ReadyExt},
Err, Error, Result,
};
use database::{Deserialized, Ignore, Interfix, Map};
use futures::{Stream, StreamExt};
use ruma::{
api::client::error::ErrorKind,
events::{
room::power_levels::{RoomPowerLevels, RoomPowerLevelsEventContent},
StateEventType,
},
OwnedRoomAliasId, OwnedRoomId, OwnedServerName, RoomAliasId, RoomId, RoomOrAliasId, UserId,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomAliasId, RoomId, RoomOrAliasId, UserId,
};
use self::data::Data;
use crate::{admin, appservice, appservice::RegistrationInfo, globals, rooms, sending, Dep};
pub struct Service {
@ -21,6 +25,12 @@ pub struct Service {
services: Services,
}
struct Data {
alias_userid: Arc<Map>,
alias_roomid: Arc<Map>,
aliasid_alias: Arc<Map>,
}
struct Services {
admin: Dep<admin::Service>,
appservice: Dep<appservice::Service>,
@ -32,7 +42,11 @@ struct Services {
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
db: Data {
alias_userid: args.db["alias_userid"].clone(),
alias_roomid: args.db["alias_roomid"].clone(),
aliasid_alias: args.db["aliasid_alias"].clone(),
},
services: Services {
admin: args.depend::<admin::Service>("admin"),
appservice: args.depend::<appservice::Service>("appservice"),
@ -50,25 +64,52 @@ impl Service {
#[tracing::instrument(skip(self))]
pub fn set_alias(&self, alias: &RoomAliasId, room_id: &RoomId, user_id: &UserId) -> Result<()> {
if alias == self.services.globals.admin_alias && user_id != self.services.globals.server_user {
Err(Error::BadRequest(
return Err(Error::BadRequest(
ErrorKind::forbidden(),
"Only the server user can set this alias",
))
} else {
self.db.set_alias(alias, room_id, user_id)
));
}
// Comes first as we don't want a stuck alias
self.db
.alias_userid
.insert(alias.alias().as_bytes(), user_id.as_bytes());
self.db
.alias_roomid
.insert(alias.alias().as_bytes(), room_id.as_bytes());
let mut aliasid = room_id.as_bytes().to_vec();
aliasid.push(0xFF);
aliasid.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
self.db.aliasid_alias.insert(&aliasid, alias.as_bytes());
Ok(())
}
#[tracing::instrument(skip(self))]
pub async fn remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<()> {
if self.user_can_remove_alias(alias, user_id).await? {
self.db.remove_alias(alias)
} else {
Err(Error::BadRequest(
ErrorKind::forbidden(),
"User is not permitted to remove this alias.",
))
if !self.user_can_remove_alias(alias, user_id).await? {
return Err!(Request(Forbidden("User is not permitted to remove this alias.")));
}
let alias = alias.alias();
let Ok(room_id) = self.db.alias_roomid.qry(&alias).await else {
return Err!(Request(NotFound("Alias does not exist or is invalid.")));
};
let prefix = (&room_id, Interfix);
self.db
.aliasid_alias
.keys_prefix(&prefix)
.ignore_err()
.ready_for_each(|key: &[u8]| self.db.aliasid_alias.remove(&key))
.await;
self.db.alias_roomid.remove(alias.as_bytes());
self.db.alias_userid.remove(alias.as_bytes());
Ok(())
}
pub async fn resolve(&self, room: &RoomOrAliasId) -> Result<OwnedRoomId> {
@ -97,9 +138,9 @@ impl Service {
return self.remote_resolve(room_alias, servers).await;
}
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias)? {
Some(r) => Some(r),
None => self.resolve_appservice_alias(room_alias).await?,
let room_id: Option<OwnedRoomId> = match self.resolve_local_alias(room_alias).await {
Ok(r) => Some(r),
Err(_) => self.resolve_appservice_alias(room_alias).await?,
};
room_id.map_or_else(
@ -109,46 +150,54 @@ impl Service {
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
self.db.resolve_local_alias(alias)
pub async fn resolve_local_alias(&self, alias: &RoomAliasId) -> Result<OwnedRoomId> {
self.db.alias_roomid.qry(alias.alias()).await.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn local_aliases_for_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedRoomAliasId>> + 'a + Send> {
self.db.local_aliases_for_room(room_id)
pub fn local_aliases_for_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &RoomAliasId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.aliasid_alias
.stream_prefix(&prefix)
.ignore_err()
.map(|((Ignore, Ignore), alias): ((Ignore, Ignore), &RoomAliasId)| alias)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn all_local_aliases<'a>(&'a self) -> Box<dyn Iterator<Item = Result<(OwnedRoomId, String)>> + 'a> {
self.db.all_local_aliases()
pub fn all_local_aliases<'a>(&'a self) -> impl Stream<Item = (&RoomId, &str)> + Send + 'a {
self.db
.alias_roomid
.stream()
.ignore_err()
.map(|(alias_localpart, room_id): (&str, &RoomId)| (room_id, alias_localpart))
}
async fn user_can_remove_alias(&self, alias: &RoomAliasId, user_id: &UserId) -> Result<bool> {
let Some(room_id) = self.resolve_local_alias(alias)? else {
return Err(Error::BadRequest(ErrorKind::NotFound, "Alias not found."));
};
let room_id = self
.resolve_local_alias(alias)
.await
.map_err(|_| err!(Request(NotFound("Alias not found."))))?;
let server_user = &self.services.globals.server_user;
// The creator of an alias can remove it
if self
.db
.who_created_alias(alias)?
.is_some_and(|user| user == user_id)
.who_created_alias(alias).await
.is_ok_and(|user| user == user_id)
// Server admins can remove any local alias
|| self.services.admin.user_is_admin(user_id).await?
|| self.services.admin.user_is_admin(user_id).await
// Always allow the server service account to remove the alias, since there may not be an admin room
|| server_user == user_id
{
Ok(true)
// Checking whether the user is able to change canonical aliases of the
// room
} else if let Some(event) =
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")?
} else if let Ok(event) = self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomPowerLevels, "")
.await
{
serde_json::from_str(event.content.get())
.map_err(|_| Error::bad_database("Invalid event content for m.room.power_levels"))
@ -157,10 +206,11 @@ impl Service {
})
// If there is no power levels event, only the room creator can change
// canonical aliases
} else if let Some(event) =
self.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")?
} else if let Ok(event) = self
.services
.state_accessor
.room_state_get(&room_id, &StateEventType::RoomCreate, "")
.await
{
Ok(event.sender == user_id)
} else {
@ -168,6 +218,10 @@ impl Service {
}
}
async fn who_created_alias(&self, alias: &RoomAliasId) -> Result<OwnedUserId> {
self.db.alias_userid.qry(alias.alias()).await.deserialized()
}
async fn resolve_appservice_alias(&self, room_alias: &RoomAliasId) -> Result<Option<OwnedRoomId>> {
use ruma::api::appservice::query::query_room_alias;
@ -185,10 +239,11 @@ impl Service {
.await,
Ok(Some(_opt_result))
) {
return Ok(Some(
self.resolve_local_alias(room_alias)?
.ok_or_else(|| err!(Request(NotFound("Room does not exist."))))?,
));
return self
.resolve_local_alias(room_alias)
.await
.map_err(|_| err!(Request(NotFound("Room does not exist."))))
.map(Some);
}
}

View file

@ -24,7 +24,7 @@ impl Data {
}
}
pub(super) fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
pub(super) async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
// Check RAM cache
if let Some(result) = self.auth_chain_cache.lock().unwrap().get_mut(key) {
return Ok(Some(Arc::clone(result)));
@ -33,17 +33,14 @@ impl Data {
// We only save auth chains for single events in the db
if key.len() == 1 {
// Check DB cache
let chain = self
.shorteventid_authchain
.get(&key[0].to_be_bytes())?
.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>()
});
let chain = self.shorteventid_authchain.qry(&key[0]).await.map(|chain| {
chain
.chunks_exact(size_of::<u64>())
.map(utils::u64_from_u8)
.collect::<Arc<[u64]>>()
});
if let Some(chain) = chain {
if let Ok(chain) = chain {
// Cache in RAM
self.auth_chain_cache
.lock()
@ -66,7 +63,7 @@ impl Data {
.iter()
.flat_map(|s| s.to_be_bytes().to_vec())
.collect::<Vec<u8>>(),
)?;
);
}
// Cache in RAM

View file

@ -5,7 +5,8 @@ use std::{
sync::Arc,
};
use conduit::{debug, error, trace, validated, warn, Err, Result};
use conduit::{debug, debug_error, trace, utils::IterStream, validated, warn, Err, Result};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{EventId, RoomId};
use self::data::Data;
@ -38,7 +39,7 @@ impl crate::Service for Service {
impl Service {
pub async fn event_ids_iter<'a>(
&'a self, room_id: &RoomId, starting_events_: Vec<Arc<EventId>>,
) -> Result<impl Iterator<Item = Arc<EventId>> + 'a> {
) -> Result<impl Stream<Item = Arc<EventId>> + Send + 'a> {
let mut starting_events: Vec<&EventId> = Vec::with_capacity(starting_events_.len());
for starting_event in &starting_events_ {
starting_events.push(starting_event);
@ -48,7 +49,13 @@ impl Service {
.get_auth_chain(room_id, &starting_events)
.await?
.into_iter()
.filter_map(move |sid| self.services.short.get_eventid_from_short(sid).ok()))
.stream()
.filter_map(|sid| {
self.services
.short
.get_eventid_from_short(sid)
.map(Result::ok)
}))
}
#[tracing::instrument(skip_all, name = "auth_chain")]
@ -61,7 +68,8 @@ impl Service {
for (i, &short) in self
.services
.short
.multi_get_or_create_shorteventid(starting_events)?
.multi_get_or_create_shorteventid(starting_events)
.await
.iter()
.enumerate()
{
@ -85,7 +93,7 @@ impl Service {
}
let chunk_key: Vec<u64> = chunk.iter().map(|(short, _)| short).copied().collect();
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key)? {
if let Some(cached) = self.get_cached_eventid_authchain(&chunk_key).await? {
trace!("Found cache entry for whole chunk");
full_auth_chain.extend(cached.iter().copied());
hits = hits.saturating_add(1);
@ -96,12 +104,12 @@ impl Service {
let mut misses2: usize = 0;
let mut chunk_cache = Vec::with_capacity(chunk.len());
for (sevent_id, event_id) in chunk {
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id])? {
if let Some(cached) = self.get_cached_eventid_authchain(&[sevent_id]).await? {
trace!(?event_id, "Found cache entry for event");
chunk_cache.extend(cached.iter().copied());
hits2 = hits2.saturating_add(1);
} else {
let auth_chain = self.get_auth_chain_inner(room_id, event_id)?;
let auth_chain = self.get_auth_chain_inner(room_id, event_id).await?;
self.cache_auth_chain(vec![sevent_id], &auth_chain)?;
chunk_cache.extend(auth_chain.iter());
misses2 = misses2.saturating_add(1);
@ -143,15 +151,16 @@ impl Service {
}
#[tracing::instrument(skip(self, room_id))]
fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
async fn get_auth_chain_inner(&self, room_id: &RoomId, event_id: &EventId) -> Result<HashSet<u64>> {
let mut todo = vec![Arc::from(event_id)];
let mut found = HashSet::new();
while let Some(event_id) = todo.pop() {
trace!(?event_id, "processing auth event");
match self.services.timeline.get_pdu(&event_id) {
Ok(Some(pdu)) => {
match self.services.timeline.get_pdu(&event_id).await {
Err(e) => debug_error!(?event_id, ?e, "Could not find pdu mentioned in auth events"),
Ok(pdu) => {
if pdu.room_id != room_id {
return Err!(Request(Forbidden(
"auth event {event_id:?} for incorrect room {} which is not {}",
@ -160,7 +169,11 @@ impl Service {
)));
}
for auth_event in &pdu.auth_events {
let sauthevent = self.services.short.get_or_create_shorteventid(auth_event)?;
let sauthevent = self
.services
.short
.get_or_create_shorteventid(auth_event)
.await;
if found.insert(sauthevent) {
trace!(?event_id, ?auth_event, "adding auth event to processing queue");
@ -168,20 +181,14 @@ impl Service {
}
}
},
Ok(None) => {
warn!(?event_id, "Could not find pdu mentioned in auth events");
},
Err(error) => {
error!(?event_id, ?error, "Could not load event in auth chain");
},
}
}
Ok(found)
}
pub fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key)
pub async fn get_cached_eventid_authchain(&self, key: &[u64]) -> Result<Option<Arc<[u64]>>> {
self.db.get_cached_eventid_authchain(key).await
}
#[tracing::instrument(skip(self), level = "debug")]

View file

@ -1,39 +0,0 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use ruma::{OwnedRoomId, RoomId};
pub(super) struct Data {
publicroomids: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
publicroomids: db["publicroomids"].clone(),
}
}
pub(super) fn set_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.insert(room_id.as_bytes(), &[])
}
pub(super) fn set_not_public(&self, room_id: &RoomId) -> Result<()> {
self.publicroomids.remove(room_id.as_bytes())
}
pub(super) fn is_public_room(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.publicroomids.get(room_id.as_bytes())?.is_some())
}
pub(super) fn public_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.publicroomids.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid."))
}))
}
}

View file

@ -1,36 +1,44 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, RoomId};
use self::data::Data;
use conduit::{implement, utils::stream::TryIgnore, Result};
use database::{Ignore, Map};
use futures::{Stream, StreamExt};
use ruma::RoomId;
pub struct Service {
db: Data,
}
struct Data {
publicroomids: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
publicroomids: args.db["publicroomids"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub fn set_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_public(room_id) }
#[implement(Service)]
pub fn set_public(&self, room_id: &RoomId) { self.db.publicroomids.insert(room_id.as_bytes(), &[]); }
#[tracing::instrument(skip(self), level = "debug")]
pub fn set_not_public(&self, room_id: &RoomId) -> Result<()> { self.db.set_not_public(room_id) }
#[implement(Service)]
pub fn set_not_public(&self, room_id: &RoomId) { self.db.publicroomids.remove(room_id.as_bytes()); }
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_public_room(&self, room_id: &RoomId) -> Result<bool> { self.db.is_public_room(room_id) }
#[implement(Service)]
pub async fn is_public_room(&self, room_id: &RoomId) -> bool { self.db.publicroomids.qry(room_id).await.is_ok() }
#[tracing::instrument(skip(self), level = "debug")]
pub fn public_rooms(&self) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ { self.db.public_rooms() }
#[implement(Service)]
pub fn public_rooms(&self) -> impl Stream<Item = &RoomId> + Send {
self.db
.publicroomids
.keys()
.ignore_err()
.map(|(room_id, _): (&RoomId, Ignore)| room_id)
}

File diff suppressed because it is too large Load diff

View file

@ -3,7 +3,9 @@ use ruma::{CanonicalJsonObject, OwnedEventId, OwnedRoomId, RoomId};
use serde_json::value::RawValue as RawJsonValue;
impl super::Service {
pub fn parse_incoming_pdu(&self, pdu: &RawJsonValue) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
pub async fn parse_incoming_pdu(
&self, pdu: &RawJsonValue,
) -> Result<(OwnedEventId, CanonicalJsonObject, OwnedRoomId)> {
let value: CanonicalJsonObject = serde_json::from_str(pdu.get()).map_err(|e| {
debug_warn!("Error parsing incoming event {pdu:#?}");
err!(BadServerResponse("Error parsing incoming event {e:?}"))
@ -14,7 +16,7 @@ impl super::Service {
.and_then(|id| RoomId::parse(id.as_str()?).ok())
.ok_or(err!(Request(InvalidParam("Invalid room id in pdu"))))?;
let Ok(room_version_id) = self.services.state.get_room_version(&room_id) else {
let Ok(room_version_id) = self.services.state.get_room_version(&room_id).await else {
return Err!("Server is not in room {room_id}");
};

View file

@ -1,65 +0,0 @@
use std::sync::Arc;
use conduit::Result;
use database::{Database, Map};
use ruma::{DeviceId, RoomId, UserId};
pub(super) struct Data {
lazyloadedids: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
lazyloadedids: db["lazyloadedids"].clone(),
}
}
pub(super) fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(device_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
key.push(0xFF);
key.extend_from_slice(ll_user.as_bytes());
Ok(self.lazyloadedids.get(&key)?.is_some())
}
pub(super) fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId,
confirmed_user_ids: &mut dyn Iterator<Item = &UserId>,
) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in confirmed_user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.lazyloadedids.insert(&key, &[])?;
}
Ok(())
}
pub(super) fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for (key, _) in self.lazyloadedids.scan_prefix(prefix) {
self.lazyloadedids.remove(&key)?;
}
Ok(())
}
}

View file

@ -1,21 +1,26 @@
mod data;
use std::{
collections::{HashMap, HashSet},
fmt::Write,
sync::{Arc, Mutex},
};
use conduit::{PduCount, Result};
use conduit::{
implement,
utils::{stream::TryIgnore, ReadyExt},
PduCount, Result,
};
use database::{Interfix, Map};
use ruma::{DeviceId, OwnedDeviceId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use self::data::Data;
pub struct Service {
pub lazy_load_waiting: Mutex<LazyLoadWaiting>,
lazy_load_waiting: Mutex<LazyLoadWaiting>,
db: Data,
}
struct Data {
lazyloadedids: Arc<Map>,
}
type LazyLoadWaiting = HashMap<LazyLoadWaitingKey, LazyLoadWaitingVal>;
type LazyLoadWaitingKey = (OwnedUserId, OwnedDeviceId, OwnedRoomId, PduCount);
type LazyLoadWaitingVal = HashSet<OwnedUserId>;
@ -23,8 +28,10 @@ type LazyLoadWaitingVal = HashSet<OwnedUserId>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
lazy_load_waiting: Mutex::new(HashMap::new()),
db: Data::new(args.db),
lazy_load_waiting: LazyLoadWaiting::new().into(),
db: Data {
lazyloadedids: args.db["lazyloadedids"].clone(),
},
}))
}
@ -40,47 +47,60 @@ impl crate::Service for Service {
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> Result<bool> {
self.db
.lazy_load_was_sent_before(user_id, device_id, room_id, ll_user)
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
#[inline]
pub async fn lazy_load_was_sent_before(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, ll_user: &UserId,
) -> bool {
let key = (user_id, device_id, room_id, ll_user);
self.db.lazyloadedids.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>,
count: PduCount,
) {
self.lazy_load_waiting
.lock()
.expect("locked")
.insert((user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count), lazy_load);
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_mark_sent(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, lazy_load: HashSet<OwnedUserId>, count: PduCount,
) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), count);
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_confirm_delivery(
&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount,
) -> Result<()> {
if let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&(
user_id.to_owned(),
device_id.to_owned(),
room_id.to_owned(),
since,
)) {
self.db
.lazy_load_confirm_delivery(user_id, device_id, room_id, &mut user_ids.iter().map(|u| &**u))?;
} else {
// Ignore
}
self.lazy_load_waiting
.lock()
.expect("locked")
.insert(key, lazy_load);
}
Ok(())
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_confirm_delivery(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId, since: PduCount) {
let key = (user_id.to_owned(), device_id.to_owned(), room_id.to_owned(), since);
#[tracing::instrument(skip(self), level = "debug")]
pub fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) -> Result<()> {
self.db.lazy_load_reset(user_id, device_id, room_id)
let Some(user_ids) = self.lazy_load_waiting.lock().expect("locked").remove(&key) else {
return;
};
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
prefix.extend_from_slice(device_id.as_bytes());
prefix.push(0xFF);
prefix.extend_from_slice(room_id.as_bytes());
prefix.push(0xFF);
for ll_id in &user_ids {
let mut key = prefix.clone();
key.extend_from_slice(ll_id.as_bytes());
self.db.lazyloadedids.insert(&key, &[]);
}
}
#[implement(Service)]
#[tracing::instrument(skip(self), level = "debug")]
pub async fn lazy_load_reset(&self, user_id: &UserId, device_id: &DeviceId, room_id: &RoomId) {
let prefix = (user_id, device_id, room_id, Interfix);
self.db
.lazyloadedids
.keys_raw_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.db.lazyloadedids.remove(key))
.await;
}

View file

@ -1,110 +0,0 @@
use std::sync::Arc;
use conduit::{error, utils, Error, Result};
use database::Map;
use ruma::{OwnedRoomId, RoomId};
use crate::{rooms, Dep};
pub(super) struct Data {
disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
services: Services,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
Self {
disabledroomids: db["disabledroomids"].clone(),
bannedroomids: db["bannedroomids"].clone(),
roomid_shortroomid: db["roomid_shortroomid"].clone(),
pduid_pdu: db["pduid_pdu"].clone(),
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}
}
pub(super) fn exists(&self, room_id: &RoomId) -> Result<bool> {
let prefix = match self.services.short.get_shortroomid(room_id)? {
Some(b) => b.to_be_bytes().to_vec(),
None => return Ok(false),
};
// Look for PDUs in that room.
Ok(self
.pduid_pdu
.iter_from(&prefix, false)
.next()
.filter(|(k, _)| k.starts_with(&prefix))
.is_some())
}
pub(super) fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.roomid_shortroomid.iter().map(|(bytes, _)| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Room ID in publicroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in roomid_shortroomid is invalid."))
}))
}
#[inline]
pub(super) fn is_disabled(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.disabledroomids.get(room_id.as_bytes())?.is_some())
}
#[inline]
pub(super) fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
if disabled {
self.disabledroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.disabledroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
#[inline]
pub(super) fn is_banned(&self, room_id: &RoomId) -> Result<bool> {
Ok(self.bannedroomids.get(room_id.as_bytes())?.is_some())
}
#[inline]
pub(super) fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> {
if banned {
self.bannedroomids.insert(room_id.as_bytes(), &[])?;
} else {
self.bannedroomids.remove(room_id.as_bytes())?;
}
Ok(())
}
pub(super) fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
Box::new(self.bannedroomids.iter().map(
|(room_id_bytes, _ /* non-banned rooms should not be in this table */)| {
let room_id = utils::string_from_bytes(&room_id_bytes)
.map_err(|e| {
error!("Invalid room_id bytes in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids.")
})?
.try_into()
.map_err(|e| {
error!("Invalid room_id in bannedroomids: {e}");
Error::bad_database("Invalid room_id in bannedroomids")
})?;
Ok(room_id)
},
))
}
}

View file

@ -1,51 +1,92 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, RoomId};
use conduit::{implement, utils::stream::TryIgnore, Result};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::RoomId;
use self::data::Data;
use crate::{rooms, Dep};
pub struct Service {
db: Data,
services: Services,
}
struct Data {
disabledroomids: Arc<Map>,
bannedroomids: Arc<Map>,
roomid_shortroomid: Arc<Map>,
pduid_pdu: Arc<Map>,
}
struct Services {
short: Dep<rooms::short::Service>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(&args),
db: Data {
disabledroomids: args.db["disabledroomids"].clone(),
bannedroomids: args.db["bannedroomids"].clone(),
roomid_shortroomid: args.db["roomid_shortroomid"].clone(),
pduid_pdu: args.db["pduid_pdu"].clone(),
},
services: Services {
short: args.depend::<rooms::short::Service>("rooms::short"),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Checks if a room exists.
#[inline]
pub fn exists(&self, room_id: &RoomId) -> Result<bool> { self.db.exists(room_id) }
#[implement(Service)]
pub async fn exists(&self, room_id: &RoomId) -> bool {
let Ok(prefix) = self.services.short.get_shortroomid(room_id).await else {
return false;
};
#[must_use]
pub fn iter_ids<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> { self.db.iter_ids() }
// Look for PDUs in that room.
self.db
.pduid_pdu
.keys_raw_prefix(&prefix)
.ignore_err()
.next()
.await
.is_some()
}
#[inline]
pub fn is_disabled(&self, room_id: &RoomId) -> Result<bool> { self.db.is_disabled(room_id) }
#[implement(Service)]
pub fn iter_ids(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.roomid_shortroomid.keys().ignore_err() }
#[inline]
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) -> Result<()> {
self.db.disable_room(room_id, disabled)
}
#[inline]
pub fn is_banned(&self, room_id: &RoomId) -> Result<bool> { self.db.is_banned(room_id) }
#[inline]
pub fn ban_room(&self, room_id: &RoomId, banned: bool) -> Result<()> { self.db.ban_room(room_id, banned) }
#[inline]
#[must_use]
pub fn list_banned_rooms<'a>(&'a self) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
self.db.list_banned_rooms()
#[implement(Service)]
#[inline]
pub fn disable_room(&self, room_id: &RoomId, disabled: bool) {
if disabled {
self.db.disabledroomids.insert(room_id.as_bytes(), &[]);
} else {
self.db.disabledroomids.remove(room_id.as_bytes());
}
}
#[implement(Service)]
#[inline]
pub fn ban_room(&self, room_id: &RoomId, banned: bool) {
if banned {
self.db.bannedroomids.insert(room_id.as_bytes(), &[]);
} else {
self.db.bannedroomids.remove(room_id.as_bytes());
}
}
#[implement(Service)]
pub fn list_banned_rooms(&self) -> impl Stream<Item = &RoomId> + Send + '_ { self.db.bannedroomids.keys().ignore_err() }
#[implement(Service)]
#[inline]
pub async fn is_disabled(&self, room_id: &RoomId) -> bool { self.db.disabledroomids.qry(room_id).await.is_ok() }
#[implement(Service)]
#[inline]
pub async fn is_banned(&self, room_id: &RoomId) -> bool { self.db.bannedroomids.qry(room_id).await.is_ok() }

View file

@ -1,42 +0,0 @@
use std::sync::Arc;
use conduit::{Error, Result};
use database::{Database, Map};
use ruma::{CanonicalJsonObject, EventId};
use crate::PduEvent;
pub(super) struct Data {
eventid_outlierpdu: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
eventid_outlierpdu: db["eventid_outlierpdu"].clone(),
}
}
pub(super) fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
pub(super) fn get_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map_or(Ok(None), |pdu| {
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))
})
}
pub(super) fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
)
}
}

View file

@ -1,9 +1,7 @@
mod data;
use std::sync::Arc;
use conduit::Result;
use data::Data;
use conduit::{implement, Result};
use database::{Deserialized, Map};
use ruma::{CanonicalJsonObject, EventId};
use crate::PduEvent;
@ -12,31 +10,48 @@ pub struct Service {
db: Data,
}
struct Data {
eventid_outlierpdu: Arc<Map>,
}
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
db: Data::new(args.db),
db: Data {
eventid_outlierpdu: args.db["eventid_outlierpdu"].clone(),
},
}))
}
fn name(&self) -> &str { crate::service::make_name(std::module_path!()) }
}
impl Service {
/// Returns the pdu from the outlier tree.
pub fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.db.get_outlier_pdu_json(event_id)
}
/// Returns the pdu from the outlier tree.
///
/// TODO: use this?
#[allow(dead_code)]
pub fn get_pdu_outlier(&self, event_id: &EventId) -> Result<Option<PduEvent>> { self.db.get_outlier_pdu(event_id) }
/// Append the PDU as an outlier.
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) -> Result<()> {
self.db.add_pdu_outlier(event_id, pdu)
}
/// Returns the pdu from the outlier tree.
#[implement(Service)]
pub async fn get_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
self.db
.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Returns the pdu from the outlier tree.
#[implement(Service)]
pub async fn get_pdu_outlier(&self, event_id: &EventId) -> Result<PduEvent> {
self.db
.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Append the PDU as an outlier.
#[implement(Service)]
#[tracing::instrument(skip(self, pdu), level = "debug")]
pub fn add_pdu_outlier(&self, event_id: &EventId, pdu: &CanonicalJsonObject) {
self.db.eventid_outlierpdu.insert(
event_id.as_bytes(),
&serde_json::to_vec(&pdu).expect("CanonicalJsonObject is valid"),
);
}

View file

@ -1,7 +1,13 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, PduCount, PduEvent, Result};
use conduit::{
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
PduCount, PduEvent,
};
use database::Map;
use futures::{Stream, StreamExt};
use ruma::{EventId, RoomId, UserId};
use crate::{rooms, Dep};
@ -17,8 +23,7 @@ struct Services {
timeline: Dep<rooms::timeline::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
pub(super) type PdusIterItem = (PduCount, PduEvent);
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
@ -33,19 +38,17 @@ impl Data {
}
}
pub(super) fn add_relation(&self, from: u64, to: u64) -> Result<()> {
pub(super) fn add_relation(&self, from: u64, to: u64) {
let mut key = to.to_be_bytes().to_vec();
key.extend_from_slice(&from.to_be_bytes());
self.tofrom_relation.insert(&key, &[])?;
Ok(())
self.tofrom_relation.insert(&key, &[]);
}
pub(super) fn relations_until<'a>(
&'a self, user_id: &'a UserId, shortroomid: u64, target: u64, until: PduCount,
) -> Result<PdusIterator<'a>> {
) -> impl Stream<Item = PdusIterItem> + Send + 'a + '_ {
let prefix = target.to_be_bytes().to_vec();
let mut current = prefix.clone();
let count_raw = match until {
PduCount::Normal(x) => x.saturating_sub(1),
PduCount::Backfilled(x) => {
@ -55,53 +58,42 @@ impl Data {
};
current.extend_from_slice(&count_raw.to_be_bytes());
Ok(Box::new(
self.tofrom_relation
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(tofrom, _data)| {
let from = utils::u64_from_bytes(&tofrom[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid count in tofrom_relation."))?;
self.tofrom_relation
.rev_raw_keys_from(&current)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|to_from| utils::u64_from_u8(&to_from[(size_of::<u64>())..]))
.filter_map(move |from| async move {
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
let mut pdu = self.services.timeline.get_pdu_from_id(&pduid).await.ok()?;
let mut pduid = shortroomid.to_be_bytes().to_vec();
pduid.extend_from_slice(&from.to_be_bytes());
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Pdu in tofrom_relation is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((PduCount::Normal(from), pdu))
}),
))
Some((PduCount::Normal(from), pdu))
})
}
pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
pub(super) fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) {
for prev in event_ids {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(prev.as_bytes());
self.referencedevents.insert(&key, &[])?;
self.referencedevents.insert(&key, &[]);
}
Ok(())
}
pub(super) fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let mut key = room_id.as_bytes().to_vec();
key.extend_from_slice(event_id.as_bytes());
Ok(self.referencedevents.get(&key)?.is_some())
pub(super) async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
let key = (room_id, event_id);
self.referencedevents.qry(&key).await.is_ok()
}
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> {
self.softfailedeventids.insert(event_id.as_bytes(), &[])
pub(super) fn mark_event_soft_failed(&self, event_id: &EventId) {
self.softfailedeventids.insert(event_id.as_bytes(), &[]);
}
pub(super) fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> {
self.softfailedeventids
.get(event_id.as_bytes())
.map(|o| o.is_some())
pub(super) async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.softfailedeventids.qry(event_id).await.is_ok()
}
}

View file

@ -1,8 +1,8 @@
mod data;
use std::sync::Arc;
use conduit::{PduCount, PduEvent, Result};
use conduit::{utils::stream::IterStream, PduCount, Result};
use futures::StreamExt;
use ruma::{
api::{client::relations::get_relating_events, Direction},
events::{relation::RelationType, TimelineEventType},
@ -10,7 +10,7 @@ use ruma::{
};
use serde::Deserialize;
use self::data::Data;
use self::data::{Data, PdusIterItem};
use crate::{rooms, Dep};
pub struct Service {
@ -51,21 +51,19 @@ impl crate::Service for Service {
impl Service {
#[tracing::instrument(skip(self, from, to), level = "debug")]
pub fn add_relation(&self, from: PduCount, to: PduCount) -> Result<()> {
pub fn add_relation(&self, from: PduCount, to: PduCount) {
match (from, to) {
(PduCount::Normal(f), PduCount::Normal(t)) => self.db.add_relation(f, t),
_ => {
// TODO: Relations with backfilled pdus
Ok(())
},
}
}
#[allow(clippy::too_many_arguments)]
pub fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: &Option<TimelineEventType>,
filter_rel_type: &Option<RelationType>, from: &Option<String>, to: &Option<String>, limit: &Option<UInt>,
pub async fn paginate_relations_with_filter(
&self, sender_user: &UserId, room_id: &RoomId, target: &EventId, filter_event_type: Option<TimelineEventType>,
filter_rel_type: Option<RelationType>, from: Option<&String>, to: Option<&String>, limit: Option<UInt>,
recurse: bool, dir: Direction,
) -> Result<get_relating_events::v1::Response> {
let from = match from {
@ -76,7 +74,7 @@ impl Service {
},
};
let to = to.as_ref().and_then(|t| PduCount::try_from_string(t).ok());
let to = to.and_then(|t| PduCount::try_from_string(t).ok());
// Use limit or else 10, with maximum 100
let limit = limit
@ -92,30 +90,32 @@ impl Service {
1
};
let relations_until = &self.relations_until(sender_user, room_id, target, from, depth)?;
let events: Vec<_> = relations_until // TODO: should be relations_after
.iter()
.filter(|(_, pdu)| {
filter_event_type.as_ref().map_or(true, |t| &pdu.kind == t)
&& if let Ok(content) =
serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get())
{
filter_rel_type
.as_ref()
.map_or(true, |r| &content.relates_to.rel_type == r)
} else {
false
}
})
.take(limit)
.filter(|(_, pdu)| {
self.services
.state_accessor
.user_can_see_event(sender_user, room_id, &pdu.event_id)
.unwrap_or(false)
})
.take_while(|(k, _)| Some(k) != to.as_ref()) // Stop at `to`
.collect();
let relations_until: Vec<PdusIterItem> = self
.relations_until(sender_user, room_id, target, from, depth)
.await?;
// TODO: should be relations_after
let events: Vec<_> = relations_until
.into_iter()
.filter(move |(_, pdu): &PdusIterItem| {
if !filter_event_type.as_ref().map_or(true, |t| pdu.kind == *t) {
return false;
}
let Ok(content) = serde_json::from_str::<ExtractRelatesToEventId>(pdu.content.get()) else {
return false;
};
filter_rel_type
.as_ref()
.map_or(true, |r| *r == content.relates_to.rel_type)
})
.take(limit)
.take_while(|(k, _)| Some(*k) != to)
.stream()
.filter_map(|item| self.visibility_filter(sender_user, item))
.collect()
.await;
let next_token = events.last().map(|(count, _)| count).copied();
@ -125,9 +125,9 @@ impl Service {
.map(|(_, pdu)| pdu.to_message_like_event())
.collect(),
Direction::Backward => events
.into_iter()
.rev() // relations are always most recent first
.map(|(_, pdu)| pdu.to_message_like_event())
.into_iter()
.rev() // relations are always most recent first
.map(|(_, pdu)| pdu.to_message_like_event())
.collect(),
};
@ -135,68 +135,85 @@ impl Service {
chunk: events_chunk,
next_batch: next_token.map(|t| t.stringify()),
prev_batch: Some(from.stringify()),
recursion_depth: if recurse {
Some(depth.into())
} else {
None
},
recursion_depth: recurse.then_some(depth.into()),
})
}
pub fn relations_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, target: &'a EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<(PduCount, PduEvent)>> {
let room_id = self.services.short.get_or_create_shortroomid(room_id)?;
#[allow(unknown_lints)]
#[allow(clippy::manual_unwrap_or_default)]
let target = match self.services.timeline.get_pdu_count(target)? {
Some(PduCount::Normal(c)) => c,
async fn visibility_filter(&self, sender_user: &UserId, item: PdusIterItem) -> Option<PdusIterItem> {
let (_, pdu) = &item;
self.services
.state_accessor
.user_can_see_event(sender_user, &pdu.room_id, &pdu.event_id)
.await
.then_some(item)
}
pub async fn relations_until(
&self, user_id: &UserId, room_id: &RoomId, target: &EventId, until: PduCount, max_depth: u8,
) -> Result<Vec<PdusIterItem>> {
let room_id = self.services.short.get_or_create_shortroomid(room_id).await;
let target = match self.services.timeline.get_pdu_count(target).await {
Ok(PduCount::Normal(c)) => c,
// TODO: Support backfilled relations
_ => 0, // This will result in an empty iterator
};
self.db
let mut pdus: Vec<PdusIterItem> = self
.db
.relations_until(user_id, room_id, target, until)
.map(|mut relations| {
let mut pdus: Vec<_> = (*relations).into_iter().filter_map(Result::ok).collect();
let mut stack: Vec<_> = pdus.clone().iter().map(|pdu| (pdu.to_owned(), 1)).collect();
.collect()
.await;
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
let mut stack: Vec<_> = pdus.clone().into_iter().map(|pdu| (pdu, 1)).collect();
if let Ok(relations) = self.db.relations_until(user_id, room_id, target, until) {
for relation in relations.flatten() {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1.saturating_add(1)));
}
while let Some(stack_pdu) = stack.pop() {
let target = match stack_pdu.0 .0 {
PduCount::Normal(c) => c,
// TODO: Support backfilled relations
PduCount::Backfilled(_) => 0, // This will result in an empty iterator
};
pdus.push(relation);
}
}
let relations: Vec<PdusIterItem> = self
.db
.relations_until(user_id, room_id, target, until)
.collect()
.await;
for relation in relations {
if stack_pdu.1 < max_depth {
stack.push((relation.clone(), stack_pdu.1.saturating_add(1)));
}
pdus.sort_by(|a, b| a.0.cmp(&b.0));
pdus
})
pdus.push(relation);
}
}
pdus.sort_by(|a, b| a.0.cmp(&b.0));
Ok(pdus)
}
#[inline]
#[tracing::instrument(skip_all, level = "debug")]
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) -> Result<()> {
self.db.mark_as_referenced(room_id, event_ids)
pub fn mark_as_referenced(&self, room_id: &RoomId, event_ids: &[Arc<EventId>]) {
self.db.mark_as_referenced(room_id, event_ids);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
self.db.is_event_referenced(room_id, event_id)
pub async fn is_event_referenced(&self, room_id: &RoomId, event_id: &EventId) -> bool {
self.db.is_event_referenced(room_id, event_id).await
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_event_soft_failed(&self, event_id: &EventId) -> Result<()> { self.db.mark_event_soft_failed(event_id) }
pub fn mark_event_soft_failed(&self, event_id: &EventId) { self.db.mark_event_soft_failed(event_id) }
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_event_soft_failed(&self, event_id: &EventId) -> Result<bool> { self.db.is_event_soft_failed(event_id) }
pub async fn is_event_soft_failed(&self, event_id: &EventId) -> bool {
self.db.is_event_soft_failed(event_id).await
}
}

View file

@ -1,10 +1,18 @@
use std::{mem::size_of, sync::Arc};
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{events::receipt::ReceiptEvent, serde::Raw, CanonicalJsonObject, RoomId, UserId};
use conduit::{
utils,
utils::{stream::TryIgnore, ReadyExt},
Error, Result,
};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{
events::{receipt::ReceiptEvent, AnySyncEphemeralRoomEvent},
serde::Raw,
CanonicalJsonObject, OwnedUserId, RoomId, UserId,
};
use super::AnySyncEphemeralRoomEventIter;
use crate::{globals, Dep};
pub(super) struct Data {
@ -18,6 +26,8 @@ struct Services {
globals: Dep<globals::Service>,
}
pub(super) type ReceiptItem = (OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>);
impl Data {
pub(super) fn new(args: &crate::Args<'_>) -> Self {
let db = &args.db;
@ -31,7 +41,9 @@ impl Data {
}
}
pub(super) fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
pub(super) async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
type KeyVal<'a> = (&'a RoomId, u64, &'a UserId);
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
@ -39,108 +51,90 @@ impl Data {
last_possible_key.extend_from_slice(&u64::MAX.to_be_bytes());
// Remove old entry
if let Some((old, _)) = self
.readreceiptid_readreceipt
.iter_from(&last_possible_key, true)
.take_while(|(key, _)| key.starts_with(&prefix))
.find(|(key, _)| {
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element")
== user_id.as_bytes()
}) {
// This is the old room_latest
self.readreceiptid_readreceipt.remove(&old)?;
}
self.readreceiptid_readreceipt
.rev_keys_from_raw(&last_possible_key)
.ignore_err()
.ready_take_while(|(r, ..): &KeyVal<'_>| *r == room_id)
.ready_filter_map(|(r, c, u): KeyVal<'_>| (u == user_id).then_some((r, c, u)))
.ready_for_each(|old: KeyVal<'_>| {
// This is the old room_latest
self.readreceiptid_readreceipt.del(&old);
})
.await;
let mut room_latest_id = prefix;
room_latest_id.extend_from_slice(&self.services.globals.next_count()?.to_be_bytes());
room_latest_id.extend_from_slice(&self.services.globals.next_count().unwrap().to_be_bytes());
room_latest_id.push(0xFF);
room_latest_id.extend_from_slice(user_id.as_bytes());
self.readreceiptid_readreceipt.insert(
&room_latest_id,
&serde_json::to_vec(event).expect("EduEvent::to_string always works"),
)?;
Ok(())
);
}
pub(super) fn readreceipts_since<'a>(&'a self, room_id: &RoomId, since: u64) -> AnySyncEphemeralRoomEventIter<'a> {
pub(super) fn readreceipts_since<'a>(
&'a self, room_id: &'a RoomId, since: u64,
) -> impl Stream<Item = ReceiptItem> + Send + 'a {
let after_since = since.saturating_add(1); // +1 so we don't send the event at since
let first_possible_edu = (room_id, after_since);
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
let prefix2 = prefix.clone();
let mut first_possible_edu = prefix.clone();
first_possible_edu.extend_from_slice(&(since.saturating_add(1)).to_be_bytes()); // +1 so we don't send the event at since
self.readreceiptid_readreceipt
.stream_raw_from(&first_possible_edu)
.ignore_err()
.ready_take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count_offset = prefix.len().saturating_add(size_of::<u64>());
let user_id_offset = count_offset.saturating_add(1);
Box::new(
self.readreceiptid_readreceipt
.iter_from(&first_possible_edu, false)
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(k, v)| {
let count_offset = prefix.len().saturating_add(size_of::<u64>());
let count = utils::u64_from_bytes(&k[prefix.len()..count_offset])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id_offset = count_offset.saturating_add(1);
let user_id = UserId::parse(
utils::string_from_bytes(&k[user_id_offset..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?,
)
let count = utils::u64_from_bytes(&k[prefix.len()..count_offset])
.map_err(|_| Error::bad_database("Invalid readreceiptid count in db."))?;
let user_id_str = utils::string_from_bytes(&k[user_id_offset..])
.map_err(|_| Error::bad_database("Invalid readreceiptid userid bytes in db."))?;
let user_id = UserId::parse(user_id_str)
.map_err(|_| Error::bad_database("Invalid readreceiptid userid in db."))?;
let mut json = serde_json::from_slice::<CanonicalJsonObject>(&v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
json.remove("room_id");
let mut json = serde_json::from_slice::<CanonicalJsonObject>(v)
.map_err(|_| Error::bad_database("Read receipt in roomlatestid_roomlatest is invalid json."))?;
Ok((
user_id,
count,
Raw::from_json(serde_json::value::to_raw_value(&json).expect("json is valid raw value")),
))
}),
)
json.remove("room_id");
let event = Raw::from_json(serde_json::value::to_raw_value(&json)?);
Ok((user_id, count, event))
})
.ignore_err()
}
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
pub(super) fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.insert(&key, &count.to_be_bytes())?;
.insert(&key, &count.to_be_bytes());
self.roomuserid_lastprivatereadupdate
.insert(&key, &self.services.globals.next_count()?.to_be_bytes())
.insert(&key, &self.services.globals.next_count().unwrap().to_be_bytes());
}
pub(super) fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_privateread
.get(&key)?
.map_or(Ok(None), |v| {
Ok(Some(
utils::u64_from_bytes(&v).map_err(|_| Error::bad_database("Invalid private read marker bytes"))?,
))
})
pub(super) async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.roomuserid_privateread.qry(&key).await.deserialized()
}
pub(super) fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastprivatereadupdate
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
pub(super) async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastprivatereadupdate
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
}

View file

@ -3,16 +3,17 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{debug, Result};
use data::Data;
use futures::Stream;
use ruma::{
events::{
receipt::{ReceiptEvent, ReceiptEventContent},
AnySyncEphemeralRoomEvent, SyncEphemeralRoomEvent,
SyncEphemeralRoomEvent,
},
serde::Raw,
OwnedUserId, RoomId, UserId,
RoomId, UserId,
};
use self::data::{Data, ReceiptItem};
use crate::{sending, Dep};
pub struct Service {
@ -24,9 +25,6 @@ struct Services {
sending: Dep<sending::Service>,
}
type AnySyncEphemeralRoomEventIter<'a> =
Box<dyn Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a>;
impl crate::Service for Service {
fn build(args: crate::Args<'_>) -> Result<Arc<Self>> {
Ok(Arc::new(Self {
@ -42,44 +40,53 @@ impl crate::Service for Service {
impl Service {
/// Replaces the previous read receipt.
pub fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) -> Result<()> {
self.db.readreceipt_update(user_id, room_id, event)?;
self.services.sending.flush_room(room_id)?;
Ok(())
pub async fn readreceipt_update(&self, user_id: &UserId, room_id: &RoomId, event: &ReceiptEvent) {
self.db.readreceipt_update(user_id, room_id, event).await;
self.services
.sending
.flush_room(room_id)
.await
.expect("room flush failed");
}
/// Returns an iterator over the most recent read_receipts in a room that
/// happened after the event with id `since`.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn readreceipts_since<'a>(
&'a self, room_id: &RoomId, since: u64,
) -> impl Iterator<Item = Result<(OwnedUserId, u64, Raw<AnySyncEphemeralRoomEvent>)>> + 'a {
&'a self, room_id: &'a RoomId, since: u64,
) -> impl Stream<Item = ReceiptItem> + Send + 'a {
self.db.readreceipts_since(room_id, since)
}
/// Sets a private read marker at `count`.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) -> Result<()> {
self.db.private_read_set(room_id, user_id, count)
pub fn private_read_set(&self, room_id: &RoomId, user_id: &UserId, count: u64) {
self.db.private_read_set(room_id, user_id, count);
}
/// Returns the private read marker.
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.private_read_get(room_id, user_id)
pub async fn private_read_get(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
self.db.private_read_get(room_id, user_id).await
}
/// Returns the count of the last typing update in this room.
pub fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.last_privateread_update(user_id, room_id)
#[inline]
pub async fn last_privateread_update(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.last_privateread_update(user_id, room_id).await
}
}
#[must_use]
pub fn pack_receipts(receipts: AnySyncEphemeralRoomEventIter<'_>) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>> {
pub fn pack_receipts<I>(receipts: I) -> Raw<SyncEphemeralRoomEvent<ReceiptEventContent>>
where
I: Iterator<Item = ReceiptItem>,
{
let mut json = BTreeMap::new();
for (_user, _count, value) in receipts.flatten() {
for (_, _, value) in receipts {
let receipt = serde_json::from_str::<SyncEphemeralRoomEvent<ReceiptEventContent>>(value.json().get());
if let Ok(value) = receipt {
for (event, receipt) in value.content {

View file

@ -1,13 +1,12 @@
use std::sync::Arc;
use conduit::{utils, Result};
use conduit::utils::{set, stream::TryIgnore, IterStream, ReadyExt};
use database::Map;
use futures::StreamExt;
use ruma::RoomId;
use crate::{rooms, Dep};
type SearchPdusResult<'a> = Result<Option<(Box<dyn Iterator<Item = Vec<u8>> + 'a>, Vec<String>)>>;
pub(super) struct Data {
tokenids: Arc<Map>,
services: Services,
@ -28,7 +27,7 @@ impl Data {
}
}
pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
pub(super) fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
let batch = tokenize(message_body)
.map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
@ -39,11 +38,10 @@ impl Data {
})
.collect::<Vec<_>>();
self.tokenids
.insert_batch(batch.iter().map(database::KeyVal::from))
self.tokenids.insert_batch(batch.iter());
}
pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
pub(super) fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
let batch = tokenize(message_body).map(|word| {
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(word.as_bytes());
@ -53,46 +51,53 @@ impl Data {
});
for token in batch {
self.tokenids.remove(&token)?;
self.tokenids.remove(&token);
}
Ok(())
}
pub(super) fn search_pdus<'a>(&'a self, room_id: &RoomId, search_string: &str) -> SearchPdusResult<'a> {
pub(super) async fn search_pdus(
&self, room_id: &RoomId, search_string: &str,
) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get_shortroomid(room_id)
.await
.ok()?
.to_be_bytes()
.to_vec();
let words: Vec<_> = tokenize(search_string).collect();
let iterators = words.clone().into_iter().map(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let bufs: Vec<_> = words
.clone()
.into_iter()
.stream()
.then(move |word| {
let mut prefix2 = prefix.clone();
prefix2.extend_from_slice(word.as_bytes());
prefix2.push(0xFF);
let prefix3 = prefix2.clone();
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
let mut last_possible_id = prefix2.clone();
last_possible_id.extend_from_slice(&u64::MAX.to_be_bytes());
self.tokenids
.iter_from(&last_possible_id, true) // Newest pdus first
.take_while(move |(k, _)| k.starts_with(&prefix2))
.map(move |(key, _)| key[prefix3.len()..].to_vec())
});
self.tokenids
.rev_raw_keys_from(&last_possible_id) // Newest pdus first
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix2))
.map(move |key| key[prefix3.len()..].to_vec())
.collect::<Vec<_>>()
})
.collect()
.await;
let Some(common_elements) = utils::common_elements(iterators, |a, b| {
// We compare b with a because we reversed the iterator earlier
b.cmp(a)
}) else {
return Ok(None);
};
Ok(Some((Box::new(common_elements), words)))
Some((
set::intersection(bufs.iter().map(|buf| buf.iter()))
.cloned()
.collect(),
words,
))
}
}
@ -100,7 +105,7 @@ impl Data {
///
/// This may be used to tokenize both message bodies (for indexing) or search
/// queries (for querying).
fn tokenize(body: &str) -> impl Iterator<Item = String> + '_ {
fn tokenize(body: &str) -> impl Iterator<Item = String> + Send + '_ {
body.split_terminator(|c: char| !c.is_alphanumeric())
.filter(|s| !s.is_empty())
.filter(|word| word.len() <= 50)

View file

@ -21,20 +21,21 @@ impl crate::Service for Service {
}
impl Service {
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.index_pdu(shortroomid, pdu_id, message_body)
pub fn index_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
self.db.index_pdu(shortroomid, pdu_id, message_body);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) -> Result<()> {
self.db.deindex_pdu(shortroomid, pdu_id, message_body)
pub fn deindex_pdu(&self, shortroomid: u64, pdu_id: &[u8], message_body: &str) {
self.db.deindex_pdu(shortroomid, pdu_id, message_body);
}
#[inline]
#[tracing::instrument(skip(self), level = "debug")]
pub fn search_pdus<'a>(
&'a self, room_id: &RoomId, search_string: &str,
) -> Result<Option<(impl Iterator<Item = Vec<u8>> + 'a, Vec<String>)>> {
self.db.search_pdus(room_id, search_string)
pub async fn search_pdus(&self, room_id: &RoomId, search_string: &str) -> Option<(Vec<Vec<u8>>, Vec<String>)> {
self.db.search_pdus(room_id, search_string).await
}
}

View file

@ -1,7 +1,7 @@
use std::sync::Arc;
use conduit::{utils, warn, Error, Result};
use database::Map;
use conduit::{err, utils, Error, Result};
use database::{Deserialized, Map};
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{globals, Dep};
@ -36,44 +36,46 @@ impl Data {
}
}
pub(super) fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
let short = if let Some(shorteventid) = self.eventid_shorteventid.get(event_id.as_bytes())? {
utils::u64_from_bytes(&shorteventid).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?
} else {
let shorteventid = self.services.globals.next_count()?;
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes())?;
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes())?;
shorteventid
};
pub(super) async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 {
if let Ok(shorteventid) = self.eventid_shorteventid.qry(event_id).await.deserialized() {
return shorteventid;
}
Ok(short)
let shorteventid = self.services.globals.next_count().unwrap();
self.eventid_shorteventid
.insert(event_id.as_bytes(), &shorteventid.to_be_bytes());
self.shorteventid_eventid
.insert(&shorteventid.to_be_bytes(), event_id.as_bytes());
shorteventid
}
pub(super) fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
pub(super) async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec<u64> {
let mut ret: Vec<u64> = Vec::with_capacity(event_ids.len());
let keys = event_ids
.iter()
.map(|id| id.as_bytes())
.collect::<Vec<&[u8]>>();
for (i, short) in self
.eventid_shorteventid
.multi_get(&keys)?
.multi_get(keys.iter())
.iter()
.enumerate()
{
#[allow(clippy::single_match_else)]
match short {
Some(short) => ret.push(
utils::u64_from_bytes(short).map_err(|_| Error::bad_database("Invalid shorteventid in db."))?,
utils::u64_from_bytes(short)
.map_err(|_| Error::bad_database("Invalid shorteventid in db."))
.unwrap(),
),
None => {
let short = self.services.globals.next_count()?;
let short = self.services.globals.next_count().unwrap();
self.eventid_shorteventid
.insert(keys[i], &short.to_be_bytes())?;
.insert(keys[i], &short.to_be_bytes());
self.shorteventid_eventid
.insert(&short.to_be_bytes(), keys[i])?;
.insert(&short.to_be_bytes(), keys[i]);
debug_assert!(ret.len() == i, "position of result must match input");
ret.push(short);
@ -81,115 +83,85 @@ impl Data {
}
}
Ok(ret)
ret
}
pub(super) fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
pub(super) async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let key = (event_type, state_key);
self.statekey_shortstatekey.qry(&key).await.deserialized()
}
let short = self
.statekey_shortstatekey
.get(&statekey_vec)?
.map(|shortstatekey| {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))
pub(super) async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 {
let key = (event_type.to_string(), state_key);
if let Ok(shortstatekey) = self.statekey_shortstatekey.qry(&key).await.deserialized() {
return shortstatekey;
}
let mut key = event_type.to_string().as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(state_key.as_bytes());
let shortstatekey = self.services.globals.next_count().unwrap();
self.statekey_shortstatekey
.insert(&key, &shortstatekey.to_be_bytes());
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &key);
shortstatekey
}
pub(super) async fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.shorteventid_eventid
.qry(&shorteventid)
.await
.deserialized()
.map_err(|e| err!(Database("Failed to find EventId from short {shorteventid:?}: {e:?}")))
}
pub(super) async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.shortstatekey_statekey
.qry(&shortstatekey)
.await
.deserialized()
.map_err(|e| {
err!(Database(
"Failed to find (StateEventType, state_key) from short {shortstatekey:?}: {e:?}"
))
})
.transpose()?;
Ok(short)
}
pub(super) fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
let mut statekey_vec = event_type.to_string().as_bytes().to_vec();
statekey_vec.push(0xFF);
statekey_vec.extend_from_slice(state_key.as_bytes());
let short = if let Some(shortstatekey) = self.statekey_shortstatekey.get(&statekey_vec)? {
utils::u64_from_bytes(&shortstatekey).map_err(|_| Error::bad_database("Invalid shortstatekey in db."))?
} else {
let shortstatekey = self.services.globals.next_count()?;
self.statekey_shortstatekey
.insert(&statekey_vec, &shortstatekey.to_be_bytes())?;
self.shortstatekey_statekey
.insert(&shortstatekey.to_be_bytes(), &statekey_vec)?;
shortstatekey
};
Ok(short)
}
pub(super) fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
let bytes = self
.shorteventid_eventid
.get(&shorteventid.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shorteventid does not exist"))?;
let event_id = EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in shorteventid_eventid is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in shorteventid_eventid is invalid."))?;
Ok(event_id)
}
pub(super) fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
let bytes = self
.shortstatekey_statekey
.get(&shortstatekey.to_be_bytes())?
.ok_or_else(|| Error::bad_database("Shortstatekey does not exist"))?;
let mut parts = bytes.splitn(2, |&b| b == 0xFF);
let eventtype_bytes = parts.next().expect("split always returns one entry");
let statekey_bytes = parts
.next()
.ok_or_else(|| Error::bad_database("Invalid statekey in shortstatekey_statekey."))?;
let event_type = StateEventType::from(utils::string_from_bytes(eventtype_bytes).map_err(|e| {
warn!("Event type in shortstatekey_statekey is invalid: {}", e);
Error::bad_database("Event type in shortstatekey_statekey is invalid.")
})?);
let state_key = utils::string_from_bytes(statekey_bytes)
.map_err(|_| Error::bad_database("Statekey in shortstatekey_statekey is invalid unicode."))?;
let result = (event_type, state_key);
Ok(result)
}
/// Returns (shortstatehash, already_existed)
pub(super) fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
Ok(if let Some(shortstatehash) = self.statehash_shortstatehash.get(state_hash)? {
(
utils::u64_from_bytes(&shortstatehash)
.map_err(|_| Error::bad_database("Invalid shortstatehash in db."))?,
true,
)
} else {
let shortstatehash = self.services.globals.next_count()?;
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes())?;
(shortstatehash, false)
})
pub(super) async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) {
if let Ok(shortstatehash) = self
.statehash_shortstatehash
.qry(state_hash)
.await
.deserialized()
{
return (shortstatehash, true);
}
let shortstatehash = self.services.globals.next_count().unwrap();
self.statehash_shortstatehash
.insert(state_hash, &shortstatehash.to_be_bytes());
(shortstatehash, false)
}
pub(super) fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> {
pub(super) async fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
self.roomid_shortroomid.qry(room_id).await.deserialized()
}
pub(super) async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 {
self.roomid_shortroomid
.get(room_id.as_bytes())?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid shortroomid in db.")))
.transpose()
}
pub(super) fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
Ok(if let Some(short) = self.roomid_shortroomid.get(room_id.as_bytes())? {
utils::u64_from_bytes(&short).map_err(|_| Error::bad_database("Invalid shortroomid in db."))?
} else {
let short = self.services.globals.next_count()?;
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes())?;
short
})
.qry(room_id)
.await
.deserialized()
.unwrap_or_else(|_| {
let short = self.services.globals.next_count().unwrap();
self.roomid_shortroomid
.insert(room_id.as_bytes(), &short.to_be_bytes());
short
})
}
}

View file

@ -22,38 +22,40 @@ impl crate::Service for Service {
}
impl Service {
pub fn get_or_create_shorteventid(&self, event_id: &EventId) -> Result<u64> {
self.db.get_or_create_shorteventid(event_id)
pub async fn get_or_create_shorteventid(&self, event_id: &EventId) -> u64 {
self.db.get_or_create_shorteventid(event_id).await
}
pub fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Result<Vec<u64>> {
self.db.multi_get_or_create_shorteventid(event_ids)
pub async fn multi_get_or_create_shorteventid(&self, event_ids: &[&EventId]) -> Vec<u64> {
self.db.multi_get_or_create_shorteventid(event_ids).await
}
pub fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<Option<u64>> {
self.db.get_shortstatekey(event_type, state_key)
pub async fn get_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
self.db.get_shortstatekey(event_type, state_key).await
}
pub fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> Result<u64> {
self.db.get_or_create_shortstatekey(event_type, state_key)
pub async fn get_or_create_shortstatekey(&self, event_type: &StateEventType, state_key: &str) -> u64 {
self.db
.get_or_create_shortstatekey(event_type, state_key)
.await
}
pub fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.db.get_eventid_from_short(shorteventid)
pub async fn get_eventid_from_short(&self, shorteventid: u64) -> Result<Arc<EventId>> {
self.db.get_eventid_from_short(shorteventid).await
}
pub fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.db.get_statekey_from_short(shortstatekey)
pub async fn get_statekey_from_short(&self, shortstatekey: u64) -> Result<(StateEventType, String)> {
self.db.get_statekey_from_short(shortstatekey).await
}
/// Returns (shortstatehash, already_existed)
pub fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> Result<(u64, bool)> {
self.db.get_or_create_shortstatehash(state_hash)
pub async fn get_or_create_shortstatehash(&self, state_hash: &[u8]) -> (u64, bool) {
self.db.get_or_create_shortstatehash(state_hash).await
}
pub fn get_shortroomid(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.get_shortroomid(room_id) }
pub async fn get_shortroomid(&self, room_id: &RoomId) -> Result<u64> { self.db.get_shortroomid(room_id).await }
pub fn get_or_create_shortroomid(&self, room_id: &RoomId) -> Result<u64> {
self.db.get_or_create_shortroomid(room_id)
pub async fn get_or_create_shortroomid(&self, room_id: &RoomId) -> u64 {
self.db.get_or_create_shortroomid(room_id).await
}
}

View file

@ -7,7 +7,12 @@ use std::{
sync::Arc,
};
use conduit::{checked, debug, debug_info, err, utils::math::usize_from_f64, warn, Error, Result};
use conduit::{
checked, debug, debug_info, err,
utils::{math::usize_from_f64, IterStream},
Error, Result,
};
use futures::{StreamExt, TryFutureExt};
use lru_cache::LruCache;
use ruma::{
api::{
@ -211,12 +216,15 @@ impl Service {
.as_ref()
{
return Ok(if let Some(cached) = cached {
if self.is_accessible_child(
current_room,
&cached.summary.join_rule,
&identifier,
&cached.summary.allowed_room_ids,
) {
if self
.is_accessible_child(
current_room,
&cached.summary.join_rule,
&identifier,
&cached.summary.allowed_room_ids,
)
.await
{
Some(SummaryAccessibility::Accessible(Box::new(cached.summary.clone())))
} else {
Some(SummaryAccessibility::Inaccessible)
@ -228,7 +236,9 @@ impl Service {
Ok(
if let Some(children_pdus) = self.get_stripped_space_child_events(current_room).await? {
let summary = self.get_room_summary(current_room, children_pdus, &identifier);
let summary = self
.get_room_summary(current_room, children_pdus, &identifier)
.await;
if let Ok(summary) = summary {
self.roomid_spacehierarchy_cache.lock().await.insert(
current_room.clone(),
@ -322,12 +332,15 @@ impl Service {
);
}
}
if self.is_accessible_child(
current_room,
&response.room.join_rule,
&Identifier::UserId(user_id),
&response.room.allowed_room_ids,
) {
if self
.is_accessible_child(
current_room,
&response.room.join_rule,
&Identifier::UserId(user_id),
&response.room.allowed_room_ids,
)
.await
{
return Ok(Some(SummaryAccessibility::Accessible(Box::new(summary.clone()))));
}
@ -358,7 +371,7 @@ impl Service {
}
}
fn get_room_summary(
async fn get_room_summary(
&self, current_room: &OwnedRoomId, children_state: Vec<Raw<HierarchySpaceChildEvent>>,
identifier: &Identifier<'_>,
) -> Result<SpaceHierarchyParentSummary, Error> {
@ -367,48 +380,43 @@ impl Service {
let join_rule = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")
.await
.map_or(JoinRule::Invite, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| c.join_rule)
.map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}"))))
})
.transpose()?
.unwrap_or(JoinRule::Invite);
.unwrap()
});
let allowed_room_ids = self
.services
.state_accessor
.allowed_room_ids(join_rule.clone());
if !self.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids) {
if !self
.is_accessible_child(current_room, &join_rule.clone().into(), identifier, &allowed_room_ids)
.await
{
debug!("User is not allowed to see room {room_id}");
// This error will be caught later
return Err(Error::BadRequest(ErrorKind::forbidden(), "User is not allowed to see the room"));
}
let join_rule = join_rule.into();
Ok(SpaceHierarchyParentSummary {
canonical_alias: self
.services
.state_accessor
.get_canonical_alias(room_id)
.unwrap_or(None),
name: self
.services
.state_accessor
.get_name(room_id)
.unwrap_or(None),
.await
.ok(),
name: self.services.state_accessor.get_name(room_id).await.ok(),
num_joined_members: self
.services
.state_cache
.room_joined_count(room_id)
.unwrap_or_default()
.unwrap_or_else(|| {
warn!("Room {room_id} has no member count");
0
})
.await
.unwrap_or(0)
.try_into()
.expect("user count should not be that big"),
room_id: room_id.to_owned(),
@ -416,18 +424,29 @@ impl Service {
.services
.state_accessor
.get_room_topic(room_id)
.unwrap_or(None),
world_readable: self.services.state_accessor.is_world_readable(room_id)?,
guest_can_join: self.services.state_accessor.guest_can_join(room_id)?,
.await
.ok(),
world_readable: self
.services
.state_accessor
.is_world_readable(room_id)
.await,
guest_can_join: self.services.state_accessor.guest_can_join(room_id).await,
avatar_url: self
.services
.state_accessor
.get_avatar(room_id)?
.get_avatar(room_id)
.await
.into_option()
.unwrap_or_default()
.url,
join_rule,
room_type: self.services.state_accessor.get_room_type(room_id)?,
join_rule: join_rule.into(),
room_type: self
.services
.state_accessor
.get_room_type(room_id)
.await
.ok(),
children_state,
allowed_room_ids,
})
@ -474,21 +493,22 @@ impl Service {
results.push(summary_to_chunk(*summary.clone()));
} else {
children = children
.into_iter()
.rev()
.skip_while(|(room, _)| {
if let Ok(short) = self.services.short.get_shortroomid(room)
{
short.as_ref() != short_room_ids.get(parents.len())
} else {
false
}
})
.collect::<Vec<_>>()
// skip_while doesn't implement DoubleEndedIterator, which is needed for rev
.into_iter()
.rev()
.collect();
.iter()
.rev()
.stream()
.skip_while(|(room, _)| {
self.services
.short
.get_shortroomid(room)
.map_ok(|short| Some(&short) != short_room_ids.get(parents.len()))
.unwrap_or_else(|_| false)
})
.map(Clone::clone)
.collect::<Vec<(OwnedRoomId, Vec<OwnedServerName>)>>()
.await
.into_iter()
.rev()
.collect();
if children.is_empty() {
return Err(Error::BadRequest(
@ -531,7 +551,7 @@ impl Service {
let mut short_room_ids = vec![];
for room in parents {
short_room_ids.push(self.services.short.get_or_create_shortroomid(&room)?);
short_room_ids.push(self.services.short.get_or_create_shortroomid(&room).await);
}
Some(
@ -554,7 +574,7 @@ impl Service {
async fn get_stripped_space_child_events(
&self, room_id: &RoomId,
) -> Result<Option<Vec<Raw<HierarchySpaceChildEvent>>>, Error> {
let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? else {
let Ok(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id).await else {
return Ok(None);
};
@ -562,10 +582,13 @@ impl Service {
.services
.state_accessor
.state_full_ids(current_shortstatehash)
.await?;
.await
.map_err(|e| err!(Database("State in space not found: {e}")))?;
let mut children_pdus = Vec::new();
for (key, id) in state {
let (event_type, state_key) = self.services.short.get_statekey_from_short(key)?;
let (event_type, state_key) = self.services.short.get_statekey_from_short(key).await?;
if event_type != StateEventType::SpaceChild {
continue;
}
@ -573,8 +596,9 @@ impl Service {
let pdu = self
.services
.timeline
.get_pdu(&id)?
.ok_or_else(|| Error::bad_database("Event in space state not found"))?;
.get_pdu(&id)
.await
.map_err(|e| err!(Database("Event {id:?} in space state not found: {e:?}")))?;
if serde_json::from_str::<SpaceChildEventContent>(pdu.content.get())
.ok()
@ -593,7 +617,7 @@ impl Service {
}
/// With the given identifier, checks if a room is accessable
fn is_accessible_child(
async fn is_accessible_child(
&self, current_room: &OwnedRoomId, join_rule: &SpaceRoomJoinRule, identifier: &Identifier<'_>,
allowed_room_ids: &Vec<OwnedRoomId>,
) -> bool {
@ -607,6 +631,7 @@ impl Service {
.services
.event_handler
.acl_check(server_name, room_id)
.await
.is_err()
{
return false;
@ -617,12 +642,11 @@ impl Service {
.services
.state_cache
.is_joined(user_id, current_room)
.unwrap_or_default()
|| self
.services
.state_cache
.is_invited(user_id, current_room)
.unwrap_or_default()
.await || self
.services
.state_cache
.is_invited(user_id, current_room)
.await
{
return true;
}
@ -633,22 +657,12 @@ impl Service {
for room in allowed_room_ids {
match identifier {
Identifier::UserId(user) => {
if self
.services
.state_cache
.is_joined(user, room)
.unwrap_or_default()
{
if self.services.state_cache.is_joined(user, room).await {
return true;
}
},
Identifier::ServerName(server) => {
if self
.services
.state_cache
.server_in_room(server, room)
.unwrap_or_default()
{
if self.services.state_cache.server_in_room(server, room).await {
return true;
}
},

View file

@ -1,34 +1,31 @@
use std::{collections::HashSet, sync::Arc};
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::{Database, Map};
use ruma::{EventId, OwnedEventId, RoomId};
use conduit::{
utils::{stream::TryIgnore, ReadyExt},
Result,
};
use database::{Database, Deserialized, Interfix, Map};
use ruma::{OwnedEventId, RoomId};
use super::RoomMutexGuard;
pub(super) struct Data {
shorteventid_shortstatehash: Arc<Map>,
roomid_pduleaves: Arc<Map>,
roomid_shortstatehash: Arc<Map>,
pub(super) roomid_pduleaves: Arc<Map>,
}
impl Data {
pub(super) fn new(db: &Arc<Database>) -> Self {
Self {
shorteventid_shortstatehash: db["shorteventid_shortstatehash"].clone(),
roomid_pduleaves: db["roomid_pduleaves"].clone(),
roomid_shortstatehash: db["roomid_shortstatehash"].clone(),
roomid_pduleaves: db["roomid_pduleaves"].clone(),
}
}
pub(super) fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_shortstatehash
.get(room_id.as_bytes())?
.map_or(Ok(None), |bytes| {
Ok(Some(utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash in roomid_shortstatehash")
})?))
})
pub(super) async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<u64> {
self.roomid_shortstatehash.qry(room_id).await.deserialized()
}
#[inline]
@ -37,53 +34,35 @@ impl Data {
room_id: &RoomId,
new_shortstatehash: u64,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
self.roomid_shortstatehash
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes())?;
Ok(())
.insert(room_id.as_bytes(), &new_shortstatehash.to_be_bytes());
}
pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) -> Result<()> {
pub(super) fn set_event_state(&self, shorteventid: u64, shortstatehash: u64) {
self.shorteventid_shortstatehash
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes())?;
Ok(())
.insert(&shorteventid.to_be_bytes(), &shortstatehash.to_be_bytes());
}
pub(super) fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
self.roomid_pduleaves
.scan_prefix(prefix)
.map(|(_, bytes)| {
EventId::parse_arc(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("EventID in roomid_pduleaves is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("EventId in roomid_pduleaves is invalid."))
})
.collect()
}
pub(super) fn set_forward_extremities(
pub(super) async fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
_mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
let prefix = (room_id, Interfix);
self.roomid_pduleaves
.keys_raw_prefix(&prefix)
.ignore_err()
.ready_for_each(|key| self.roomid_pduleaves.remove(key))
.await;
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
for (key, _) in self.roomid_pduleaves.scan_prefix(prefix.clone()) {
self.roomid_pduleaves.remove(&key)?;
}
for event_id in event_ids {
let mut key = prefix.clone();
key.extend_from_slice(event_id.as_bytes());
self.roomid_pduleaves.insert(&key, event_id.as_bytes())?;
self.roomid_pduleaves.insert(&key, event_id.as_bytes());
}
Ok(())
}
}

View file

@ -7,12 +7,14 @@ use std::{
};
use conduit::{
utils::{calculate_hash, MutexMap, MutexMapGuard},
warn, Error, PduEvent, Result,
err,
utils::{calculate_hash, stream::TryIgnore, IterStream, MutexMap, MutexMapGuard},
warn, PduEvent, Result,
};
use data::Data;
use database::{Ignore, Interfix};
use futures::{pin_mut, FutureExt, Stream, StreamExt, TryFutureExt, TryStreamExt};
use ruma::{
api::client::error::ErrorKind,
events::{
room::{create::RoomCreateEventContent, member::RoomMemberEventContent},
AnyStrippedStateEvent, StateEventType, TimelineEventType,
@ -81,14 +83,16 @@ impl Service {
_statediffremoved: Arc<HashSet<CompressedStateEvent>>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
for event_id in statediffnew.iter().filter_map(|new| {
let event_ids = statediffnew.iter().stream().filter_map(|new| {
self.services
.state_compressor
.parse_compressed_state_event(new)
.ok()
.map(|(_, id)| id)
}) {
let Some(pdu) = self.services.timeline.get_pdu_json(&event_id)? else {
.map_ok_or_else(|_| None, |(_, event_id)| Some(event_id))
});
pin_mut!(event_ids);
while let Some(event_id) = event_ids.next().await {
let Ok(pdu) = self.services.timeline.get_pdu_json(&event_id).await else {
continue;
};
@ -113,15 +117,10 @@ impl Service {
continue;
};
self.services.state_cache.update_membership(
room_id,
&user_id,
membership_event,
&pdu.sender,
None,
None,
false,
)?;
self.services
.state_cache
.update_membership(room_id, &user_id, membership_event, &pdu.sender, None, None, false)
.await?;
},
TimelineEventType::SpaceChild => {
self.services
@ -135,10 +134,9 @@ impl Service {
}
}
self.services.state_cache.update_joined_count(room_id)?;
self.services.state_cache.update_joined_count(room_id).await;
self.db
.set_room_state(room_id, shortstatehash, state_lock)?;
self.db.set_room_state(room_id, shortstatehash, state_lock);
Ok(())
}
@ -148,12 +146,16 @@ impl Service {
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, state_ids_compressed), level = "debug")]
pub fn set_event_state(
pub async fn set_event_state(
&self, event_id: &EventId, room_id: &RoomId, state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> Result<u64> {
let shorteventid = self.services.short.get_or_create_shorteventid(event_id)?;
let shorteventid = self
.services
.short
.get_or_create_shorteventid(event_id)
.await;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id)?;
let previous_shortstatehash = self.db.get_room_shortstatehash(room_id).await;
let state_hash = calculate_hash(
&state_ids_compressed
@ -165,13 +167,18 @@ impl Service {
let (shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;
.get_or_create_shortstatehash(&state_hash)
.await;
if !already_existed {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let states_parents = if let Ok(p) = previous_shortstatehash {
self.services
.state_compressor
.load_shortstatehash_info(p)
.await?
} else {
Vec::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = state_ids_compressed
@ -198,7 +205,7 @@ impl Service {
)?;
}
self.db.set_event_state(shorteventid, shortstatehash)?;
self.db.set_event_state(shorteventid, shortstatehash);
Ok(shortstatehash)
}
@ -208,34 +215,40 @@ impl Service {
/// This adds all current state events (not including the incoming event)
/// to `stateid_pduid` and adds the incoming event to `eventid_statehash`.
#[tracing::instrument(skip(self, new_pdu), level = "debug")]
pub fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
pub async fn append_to_state(&self, new_pdu: &PduEvent) -> Result<u64> {
let shorteventid = self
.services
.short
.get_or_create_shorteventid(&new_pdu.event_id)?;
.get_or_create_shorteventid(&new_pdu.event_id)
.await;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id)?;
let previous_shortstatehash = self.get_room_shortstatehash(&new_pdu.room_id).await;
if let Some(p) = previous_shortstatehash {
self.db.set_event_state(shorteventid, p)?;
if let Ok(p) = previous_shortstatehash {
self.db.set_event_state(shorteventid, p);
}
if let Some(state_key) = &new_pdu.state_key {
let states_parents = previous_shortstatehash.map_or_else(
|| Ok(Vec::new()),
#[inline]
|p| self.services.state_compressor.load_shortstatehash_info(p),
)?;
let states_parents = if let Ok(p) = previous_shortstatehash {
self.services
.state_compressor
.load_shortstatehash_info(p)
.await?
} else {
Vec::new()
};
let shortstatekey = self
.services
.short
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)?;
.get_or_create_shortstatekey(&new_pdu.kind.to_string().into(), state_key)
.await;
let new = self
.services
.state_compressor
.compress_state_event(shortstatekey, &new_pdu.event_id)?;
.compress_state_event(shortstatekey, &new_pdu.event_id)
.await;
let replaces = states_parents
.last()
@ -276,49 +289,55 @@ impl Service {
}
#[tracing::instrument(skip(self, invite_event), level = "debug")]
pub fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
pub async fn calculate_invite_state(&self, invite_event: &PduEvent) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let mut state = Vec::new();
// Add recommended events
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCreate, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomJoinRules, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomCanonicalAlias,
"",
)? {
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomCanonicalAlias, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) =
self.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")?
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomAvatar, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Some(e) = self.services.state_accessor.room_state_get(
&invite_event.room_id,
&StateEventType::RoomMember,
invite_event.sender.as_str(),
)? {
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomName, "")
.await
{
state.push(e.to_stripped_state_event());
}
if let Ok(e) = self
.services
.state_accessor
.room_state_get(&invite_event.room_id, &StateEventType::RoomMember, invite_event.sender.as_str())
.await
{
state.push(e.to_stripped_state_event());
}
@ -333,101 +352,108 @@ impl Service {
room_id: &RoomId,
shortstatehash: u64,
mutex_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
self.db.set_room_state(room_id, shortstatehash, mutex_lock)
) {
self.db.set_room_state(room_id, shortstatehash, mutex_lock);
}
/// Returns the room's version.
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
let create_event = self
.services
pub async fn get_room_version(&self, room_id: &RoomId) -> Result<RoomVersionId> {
self.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?;
let create_event_content: RoomCreateEventContent = create_event
.as_ref()
.map(|create_event| {
serde_json::from_str(create_event.content.get()).map_err(|e| {
warn!("Invalid create event: {}", e);
Error::bad_database("Invalid create event in db.")
})
})
.transpose()?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "No create event found"))?;
Ok(create_event_content.room_version)
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.map(|content: RoomCreateEventContent| content.room_version)
.map_err(|e| err!(Request(NotFound("No create event found: {e:?}"))))
}
#[inline]
pub fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.db.get_room_shortstatehash(room_id)
pub async fn get_room_shortstatehash(&self, room_id: &RoomId) -> Result<u64> {
self.db.get_room_shortstatehash(room_id).await
}
pub fn get_forward_extremities(&self, room_id: &RoomId) -> Result<HashSet<Arc<EventId>>> {
self.db.get_forward_extremities(room_id)
pub fn get_forward_extremities<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &EventId> + Send + '_ {
let prefix = (room_id, Interfix);
self.db
.roomid_pduleaves
.keys_prefix(&prefix)
.map_ok(|(_, event_id): (Ignore, &EventId)| event_id)
.ignore_err()
}
pub fn set_forward_extremities(
pub async fn set_forward_extremities(
&self,
room_id: &RoomId,
event_ids: Vec<OwnedEventId>,
state_lock: &RoomMutexGuard, // Take mutex guard to make sure users get the room state mutex
) -> Result<()> {
) {
self.db
.set_forward_extremities(room_id, event_ids, state_lock)
.await;
}
/// This fetches auth events from the current state.
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_auth_events(
pub async fn get_auth_events(
&self, room_id: &RoomId, kind: &TimelineEventType, sender: &UserId, state_key: Option<&str>,
content: &serde_json::value::RawValue,
) -> Result<StateMap<Arc<PduEvent>>> {
let Some(shortstatehash) = self.get_room_shortstatehash(room_id)? else {
let Ok(shortstatehash) = self.get_room_shortstatehash(room_id).await else {
return Ok(HashMap::new());
};
let auth_events =
state_res::auth_types_for_event(kind, sender, state_key, content).expect("content is a valid JSON object");
let auth_events = state_res::auth_types_for_event(kind, sender, state_key, content)?;
let mut sauthevents = auth_events
.into_iter()
let mut sauthevents: HashMap<_, _> = auth_events
.iter()
.stream()
.filter_map(|(event_type, state_key)| {
self.services
.short
.get_shortstatekey(&event_type.to_string().into(), &state_key)
.ok()
.flatten()
.map(|s| (s, (event_type, state_key)))
.get_shortstatekey(event_type, state_key)
.map_ok(move |s| (s, (event_type, state_key)))
.map(Result::ok)
})
.collect::<HashMap<_, _>>();
.collect()
.await;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| {
err!(Database(
"Missing shortstatehash info for {room_id:?} at {shortstatehash:?}: {e:?}"
))
})?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
.iter()
.filter_map(|compressed| {
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
})
.filter_map(|(shortstatekey, event_id)| sauthevents.remove(&shortstatekey).map(|k| (k, event_id)))
.filter_map(|(k, event_id)| {
self.services
.timeline
.get_pdu(&event_id)
.ok()
.flatten()
.map(|pdu| (k, pdu))
})
.collect())
let mut ret = HashMap::new();
for compressed in full_state.iter() {
let Ok((shortstatekey, event_id)) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)
.await
else {
continue;
};
let Some((ty, state_key)) = sauthevents.remove(&shortstatekey) else {
continue;
};
let Ok(pdu) = self.services.timeline.get_pdu(&event_id).await else {
continue;
};
ret.insert((ty.to_owned(), state_key.to_owned()), pdu);
}
Ok(ret)
}
}

View file

@ -1,7 +1,8 @@
use std::{collections::HashMap, sync::Arc};
use conduit::{utils, Error, PduEvent, Result};
use database::Map;
use conduit::{err, PduEvent, Result};
use database::{Deserialized, Map};
use futures::TryFutureExt;
use ruma::{events::StateEventType, EventId, RoomId};
use crate::{rooms, Dep};
@ -39,17 +40,22 @@ impl Data {
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database("Missing state IDs: {e}")))?
.pop()
.expect("there is always one layer")
.1;
let mut result = HashMap::new();
let mut i: u8 = 0;
for compressed in full_state.iter() {
let parsed = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
.parse_compressed_state_event(compressed)
.await?;
result.insert(parsed.0, parsed.1);
i = i.wrapping_add(1);
@ -57,6 +63,7 @@ impl Data {
tokio::task::yield_now().await;
}
}
Ok(result)
}
@ -67,7 +74,8 @@ impl Data {
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await?
.pop()
.expect("there is always one layer")
.1;
@ -78,18 +86,13 @@ impl Data {
let (_, eventid) = self
.services
.state_compressor
.parse_compressed_state_event(compressed)?;
if let Some(pdu) = self.services.timeline.get_pdu(&eventid)? {
result.insert(
(
pdu.kind.to_string().into(),
pdu.state_key
.as_ref()
.ok_or_else(|| Error::bad_database("State event has no state key."))?
.clone(),
),
pdu,
);
.parse_compressed_state_event(compressed)
.await?;
if let Ok(pdu) = self.services.timeline.get_pdu(&eventid).await {
if let Some(state_key) = pdu.state_key.as_ref() {
result.insert((pdu.kind.to_string().into(), state_key.clone()), pdu);
}
}
i = i.wrapping_add(1);
@ -101,61 +104,63 @@ impl Data {
Ok(result)
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
#[allow(clippy::unused_self)]
pub(super) fn state_get_id(
pub(super) async fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
let Some(shortstatekey) = self
) -> Result<Arc<EventId>> {
let shortstatekey = self
.services
.short
.get_shortstatekey(event_type, state_key)?
else {
return Ok(None);
};
.get_shortstatekey(event_type, state_key)
.await?;
let full_state = self
.services
.state_compressor
.load_shortstatehash_info(shortstatehash)?
.load_shortstatehash_info(shortstatehash)
.await
.map_err(|e| err!(Database(error!(?event_type, ?state_key, "Missing state: {e:?}"))))?
.pop()
.expect("there is always one layer")
.1;
Ok(full_state
let compressed = full_state
.iter()
.find(|bytes| bytes.starts_with(&shortstatekey.to_be_bytes()))
.and_then(|compressed| {
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.ok()
.map(|(_, id)| id)
}))
.ok_or(err!(Database("No shortstatekey in compressed state")))?;
self.services
.state_compressor
.parse_compressed_state_event(compressed)
.map_ok(|(_, id)| id)
.map_err(|e| {
err!(Database(error!(
?event_type,
?state_key,
?shortstatekey,
"Failed to parse compressed: {e:?}"
)))
})
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn state_get(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.state_get_id(shortstatehash, event_type, state_key)?
.map_or(Ok(None), |event_id| self.services.timeline.get_pdu(&event_id))
) -> Result<Arc<PduEvent>> {
self.state_get_id(shortstatehash, event_type, state_key)
.and_then(|event_id| async move { self.services.timeline.get_pdu(&event_id).await })
.await
}
/// Returns the state hash for this pdu.
pub(super) fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> {
pub(super) async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<u64> {
self.eventid_shorteventid
.get(event_id.as_bytes())?
.map_or(Ok(None), |shorteventid| {
self.shorteventid_shortstatehash
.get(&shorteventid)?
.map(|bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| {
Error::bad_database("Invalid shortstatehash bytes in shorteventid_shortstatehash")
})
})
.transpose()
})
.qry(event_id)
.and_then(|shorteventid| self.shorteventid_shortstatehash.qry(&shorteventid))
.await
.deserialized()
}
/// Returns the full room state.
@ -163,34 +168,33 @@ impl Data {
pub(super) async fn room_state_full(
&self, room_id: &RoomId,
) -> Result<HashMap<(StateEventType, String), Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_full(current_shortstatehash).await
} else {
Ok(HashMap::new())
}
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_full(shortstatehash))
.map_err(|e| err!(Database("Missing state for {room_id:?}: {e:?}")))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn room_state_get_id(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get_id(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
) -> Result<Arc<EventId>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get_id(shortstatehash, event_type, state_key))
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
pub(super) fn room_state_get(
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub(super) async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
if let Some(current_shortstatehash) = self.services.state.get_room_shortstatehash(room_id)? {
self.state_get(current_shortstatehash, event_type, state_key)
} else {
Ok(None)
}
) -> Result<Arc<PduEvent>> {
self.services
.state
.get_room_shortstatehash(room_id)
.and_then(|shortstatehash| self.state_get(shortstatehash, event_type, state_key))
.await
}
}

View file

@ -6,8 +6,13 @@ use std::{
sync::{Arc, Mutex as StdMutex, Mutex},
};
use conduit::{err, error, pdu::PduBuilder, utils::math::usize_from_f64, warn, Error, PduEvent, Result};
use data::Data;
use conduit::{
err, error,
pdu::PduBuilder,
utils::{math::usize_from_f64, ReadyExt},
Error, PduEvent, Result,
};
use futures::StreamExt;
use lru_cache::LruCache;
use ruma::{
events::{
@ -31,8 +36,10 @@ use ruma::{
EventEncryptionAlgorithm, EventId, OwnedRoomAliasId, OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName,
UserId,
};
use serde::Deserialize;
use serde_json::value::to_raw_value;
use self::data::Data;
use crate::{rooms, rooms::state::RoomMutexGuard, Dep};
pub struct Service {
@ -99,54 +106,58 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn state_get_id(
pub async fn state_get_id(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
self.db.state_get_id(shortstatehash, event_type, state_key)
) -> Result<Arc<EventId>> {
self.db
.state_get_id(shortstatehash, event_type, state_key)
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[inline]
pub fn state_get(
pub async fn state_get(
&self, shortstatehash: u64, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.db.state_get(shortstatehash, event_type, state_key)
) -> Result<Arc<PduEvent>> {
self.db
.state_get(shortstatehash, event_type, state_key)
.await
}
/// Get membership for given user in state
fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> Result<MembershipState> {
self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())?
.map_or(Ok(MembershipState::Leave), |s| {
async fn user_membership(&self, shortstatehash: u64, user_id: &UserId) -> MembershipState {
self.state_get(shortstatehash, &StateEventType::RoomMember, user_id.as_str())
.await
.map_or(MembershipState::Leave, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomMemberEventContent| c.membership)
.map_err(|_| Error::bad_database("Invalid room membership event in database."))
.unwrap()
})
}
/// The user was a joined member at this state (potentially in the past)
#[inline]
fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join)
// Return sensible default, i.e.
// false
async fn user_was_joined(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id).await == MembershipState::Join
}
/// The user was an invited or joined room member at this state (potentially
/// in the past)
#[inline]
fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
self.user_membership(shortstatehash, user_id)
.is_ok_and(|s| s == MembershipState::Join || s == MembershipState::Invite)
// Return sensible default, i.e. false
async fn user_was_invited(&self, shortstatehash: u64, user_id: &UserId) -> bool {
let s = self.user_membership(shortstatehash, user_id).await;
s == MembershipState::Join || s == MembershipState::Invite
}
/// Whether a server is allowed to see an event through federation, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, origin, room_id, event_id))]
pub fn server_can_see_event(&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else {
pub async fn server_can_see_event(
&self, origin: &ServerName, room_id: &RoomId, event_id: &EventId,
) -> Result<bool> {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return Ok(true);
};
@ -160,8 +171,9 @@ impl Service {
}
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(HistoryVisibility::Shared, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|e| {
@ -171,25 +183,28 @@ impl Service {
);
Error::bad_database("Invalid history visibility event in database.")
})
})
.unwrap_or(HistoryVisibility::Shared);
.unwrap()
});
let mut current_server_members = self
let current_server_members = self
.services
.state_cache
.room_members(room_id)
.filter_map(Result::ok)
.filter(|member| member.server_name() == origin);
.ready_filter(|member| member.server_name() == origin);
let visibility = match history_visibility {
HistoryVisibility::WorldReadable | HistoryVisibility::Shared => true,
HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
current_server_members.any(|member| self.user_was_invited(shortstatehash, &member))
current_server_members
.any(|member| self.user_was_invited(shortstatehash, member))
.await
},
HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
current_server_members.any(|member| self.user_was_joined(shortstatehash, &member))
current_server_members
.any(|member| self.user_was_joined(shortstatehash, member))
.await
},
_ => {
error!("Unknown history visibility {history_visibility}");
@ -208,9 +223,9 @@ impl Service {
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id, event_id))]
pub fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> Result<bool> {
let Some(shortstatehash) = self.pdu_shortstatehash(event_id)? else {
return Ok(true);
pub async fn user_can_see_event(&self, user_id: &UserId, room_id: &RoomId, event_id: &EventId) -> bool {
let Ok(shortstatehash) = self.pdu_shortstatehash(event_id).await else {
return true;
};
if let Some(visibility) = self
@ -219,14 +234,15 @@ impl Service {
.unwrap()
.get_mut(&(user_id.to_owned(), shortstatehash))
{
return Ok(*visibility);
return *visibility;
}
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
let history_visibility = self
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(HistoryVisibility::Shared), |s| {
.state_get(shortstatehash, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(HistoryVisibility::Shared, |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
.map_err(|e| {
@ -236,19 +252,19 @@ impl Service {
);
Error::bad_database("Invalid history visibility event in database.")
})
})
.unwrap_or(HistoryVisibility::Shared);
.unwrap()
});
let visibility = match history_visibility {
HistoryVisibility::WorldReadable => true,
HistoryVisibility::Shared => currently_member,
HistoryVisibility::Invited => {
// Allow if any member on requesting server was AT LEAST invited, else deny
self.user_was_invited(shortstatehash, user_id)
self.user_was_invited(shortstatehash, user_id).await
},
HistoryVisibility::Joined => {
// Allow if any member on requested server was joined, else deny
self.user_was_joined(shortstatehash, user_id)
self.user_was_joined(shortstatehash, user_id).await
},
_ => {
error!("Unknown history visibility {history_visibility}");
@ -261,17 +277,18 @@ impl Service {
.unwrap()
.insert((user_id.to_owned(), shortstatehash), visibility);
Ok(visibility)
visibility
}
/// Whether a user is allowed to see an event, based on
/// the room's history_visibility at that event's state.
#[tracing::instrument(skip(self, user_id, room_id))]
pub fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let currently_member = self.services.state_cache.is_joined(user_id, room_id)?;
pub async fn user_can_see_state_events(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let currently_member = self.services.state_cache.is_joined(user_id, room_id).await;
let history_visibility = self
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map_or(Ok(HistoryVisibility::Shared), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility)
@ -285,11 +302,13 @@ impl Service {
})
.unwrap_or(HistoryVisibility::Shared);
Ok(currently_member || history_visibility == HistoryVisibility::WorldReadable)
currently_member || history_visibility == HistoryVisibility::WorldReadable
}
/// Returns the state hash for this pdu.
pub fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<Option<u64>> { self.db.pdu_shortstatehash(event_id) }
pub async fn pdu_shortstatehash(&self, event_id: &EventId) -> Result<u64> {
self.db.pdu_shortstatehash(event_id).await
}
/// Returns the full room state.
#[tracing::instrument(skip(self), level = "debug")]
@ -300,47 +319,61 @@ impl Service {
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_state_get_id(
pub async fn room_state_get_id(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<EventId>>> {
self.db.room_state_get_id(room_id, event_type, state_key)
) -> Result<Arc<EventId>> {
self.db
.room_state_get_id(room_id, event_type, state_key)
.await
}
/// Returns a single PDU from `room_id` with key (`event_type`,
/// `state_key`).
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_state_get(
pub async fn room_state_get(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<Option<Arc<PduEvent>>> {
self.db.room_state_get(room_id, event_type, state_key)
) -> Result<Arc<PduEvent>> {
self.db.room_state_get(room_id, event_type, state_key).await
}
pub fn get_name(&self, room_id: &RoomId) -> Result<Option<String>> {
self.room_state_get(room_id, &StateEventType::RoomName, "")?
.map_or(Ok(None), |s| {
Ok(serde_json::from_str(s.content.get()).map_or_else(|_| None, |c: RoomNameEventContent| Some(c.name)))
})
/// Returns a single PDU from `room_id` with key (`event_type`,`state_key`).
pub async fn room_state_get_content<T>(
&self, room_id: &RoomId, event_type: &StateEventType, state_key: &str,
) -> Result<T>
where
T: for<'de> Deserialize<'de> + Send,
{
use serde_json::from_str;
self.room_state_get(room_id, event_type, state_key)
.await
.and_then(|event| from_str::<T>(event.content.get()).map_err(Into::into))
}
pub fn get_avatar(&self, room_id: &RoomId) -> Result<ruma::JsOption<RoomAvatarEventContent>> {
self.room_state_get(room_id, &StateEventType::RoomAvatar, "")?
.map_or(Ok(ruma::JsOption::Undefined), |s| {
pub async fn get_name(&self, room_id: &RoomId) -> Result<String> {
self.room_state_get_content(room_id, &StateEventType::RoomName, "")
.await
.map(|c: RoomNameEventContent| c.name)
}
pub async fn get_avatar(&self, room_id: &RoomId) -> ruma::JsOption<RoomAvatarEventContent> {
self.room_state_get(room_id, &StateEventType::RoomAvatar, "")
.await
.map_or(ruma::JsOption::Undefined, |s| {
serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room avatar event in database."))
.unwrap()
})
}
pub fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<RoomMemberEventContent>> {
self.room_state_get(room_id, &StateEventType::RoomMember, user_id.as_str())?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map_err(|_| Error::bad_database("Invalid room member event in database."))
})
pub async fn get_member(&self, room_id: &RoomId, user_id: &UserId) -> Result<RoomMemberEventContent> {
self.room_state_get_content(room_id, &StateEventType::RoomMember, user_id.as_str())
.await
}
pub fn user_can_invite(
pub async fn user_can_invite(
&self, room_id: &RoomId, sender: &UserId, target_user: &UserId, state_lock: &RoomMutexGuard,
) -> Result<bool> {
) -> bool {
let content = to_raw_value(&RoomMemberEventContent::new(MembershipState::Invite))
.expect("Event content always serializes");
@ -353,122 +386,101 @@ impl Service {
timestamp: None,
};
Ok(self
.services
self.services
.timeline
.create_hash_and_sign_event(new_event, sender, room_id, state_lock)
.is_ok())
.await
.is_ok()
}
/// Checks if guests are able to view room content without joining
pub fn is_world_readable(&self, room_id: &RoomId) -> Result<bool, Error> {
self.room_state_get(room_id, &StateEventType::RoomHistoryVisibility, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomHistoryVisibilityEventContent| {
c.history_visibility == HistoryVisibility::WorldReadable
})
.map_err(|e| {
error!(
"Invalid room history visibility event in database for room {room_id}, assuming not world \
readable: {e} "
);
Error::bad_database("Invalid room history visibility event in database.")
})
})
pub async fn is_world_readable(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomHistoryVisibility, "")
.await
.map(|c: RoomHistoryVisibilityEventContent| c.history_visibility == HistoryVisibility::WorldReadable)
.unwrap_or(false)
}
/// Checks if guests are able to join a given room
pub fn guest_can_join(&self, room_id: &RoomId) -> Result<bool, Error> {
self.room_state_get(room_id, &StateEventType::RoomGuestAccess, "")?
.map_or(Ok(false), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.map_err(|_| Error::bad_database("Invalid room guest access event in database."))
})
pub async fn guest_can_join(&self, room_id: &RoomId) -> bool {
self.room_state_get_content(room_id, &StateEventType::RoomGuestAccess, "")
.await
.map(|c: RoomGuestAccessEventContent| c.guest_access == GuestAccess::CanJoin)
.unwrap_or(false)
}
/// Gets the primary alias from canonical alias event
pub fn get_canonical_alias(&self, room_id: &RoomId) -> Result<Option<OwnedRoomAliasId>, Error> {
self.room_state_get(room_id, &StateEventType::RoomCanonicalAlias, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomCanonicalAliasEventContent| c.alias)
.map_err(|_| Error::bad_database("Invalid canonical alias event in database."))
pub async fn get_canonical_alias(&self, room_id: &RoomId) -> Result<OwnedRoomAliasId> {
self.room_state_get_content(room_id, &StateEventType::RoomCanonicalAlias, "")
.await
.and_then(|c: RoomCanonicalAliasEventContent| {
c.alias
.ok_or_else(|| err!(Request(NotFound("No alias found in event content."))))
})
}
/// Gets the room topic
pub fn get_room_topic(&self, room_id: &RoomId) -> Result<Option<String>, Error> {
self.room_state_get(room_id, &StateEventType::RoomTopic, "")?
.map_or(Ok(None), |s| {
serde_json::from_str(s.content.get())
.map(|c: RoomTopicEventContent| Some(c.topic))
.map_err(|e| {
error!("Invalid room topic event in database for room {room_id}: {e}");
Error::bad_database("Invalid room topic event in database.")
})
})
pub async fn get_room_topic(&self, room_id: &RoomId) -> Result<String> {
self.room_state_get_content(room_id, &StateEventType::RoomTopic, "")
.await
.map(|c: RoomTopicEventContent| c.topic)
}
/// Checks if a given user can redact a given event
///
/// If federation is true, it allows redaction events from any user of the
/// same server as the original event sender
pub fn user_can_redact(
pub async fn user_can_redact(
&self, redacts: &EventId, sender: &UserId, room_id: &RoomId, federation: bool,
) -> Result<bool> {
self.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map_or_else(
|| {
// Falling back on m.room.create to judge power level
if let Some(pdu) = self.room_state_get(room_id, &StateEventType::RoomCreate, "")? {
Ok(pdu.sender == sender
|| if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
pdu.sender == sender
} else {
false
})
if let Ok(event) = self
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")
.await
{
let Ok(event) = serde_json::from_str(event.content.get())
.map(|content: RoomPowerLevelsEventContent| content.into())
.map(|event: RoomPowerLevels| event)
else {
return Ok(false);
};
Ok(event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender)
&& if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await {
if federation {
pdu.sender.server_name() == sender.server_name()
} else {
pdu.sender == sender
}
} else {
Err(Error::bad_database(
"No m.room.power_levels or m.room.create events in database for room",
))
}
},
|event| {
serde_json::from_str(event.content.get())
.map(|content: RoomPowerLevelsEventContent| content.into())
.map(|event: RoomPowerLevels| {
event.user_can_redact_event_of_other(sender)
|| event.user_can_redact_own_event(sender)
&& if let Ok(Some(pdu)) = self.services.timeline.get_pdu(redacts) {
if federation {
pdu.sender.server_name() == sender.server_name()
} else {
pdu.sender == sender
}
} else {
false
}
})
.map_err(|_| Error::bad_database("Invalid m.room.power_levels event in database"))
},
)
false
})
} else {
// Falling back on m.room.create to judge power level
if let Ok(pdu) = self
.room_state_get(room_id, &StateEventType::RoomCreate, "")
.await
{
Ok(pdu.sender == sender
|| if let Ok(pdu) = self.services.timeline.get_pdu(redacts).await {
pdu.sender == sender
} else {
false
})
} else {
Err(Error::bad_database(
"No m.room.power_levels or m.room.create events in database for room",
))
}
}
}
/// Returns the join rule (`SpaceRoomJoinRule`) for a given room
pub fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>), Error> {
Ok(self
.room_state_get(room_id, &StateEventType::RoomJoinRules, "")?
.map(|s| {
serde_json::from_str(s.content.get())
.map(|c: RoomJoinRulesEventContent| {
(c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule))
})
.map_err(|e| err!(Database(error!("Invalid room join rule event in database: {e}"))))
})
.transpose()?
.unwrap_or((SpaceRoomJoinRule::Invite, vec![])))
pub async fn get_join_rule(&self, room_id: &RoomId) -> Result<(SpaceRoomJoinRule, Vec<OwnedRoomId>)> {
self.room_state_get_content(room_id, &StateEventType::RoomJoinRules, "")
.await
.map(|c: RoomJoinRulesEventContent| (c.join_rule.clone().into(), self.allowed_room_ids(c.join_rule)))
.or_else(|_| Ok((SpaceRoomJoinRule::Invite, vec![])))
}
/// Returns an empty vec if not a restricted room
@ -487,25 +499,21 @@ impl Service {
room_ids
}
pub fn get_room_type(&self, room_id: &RoomId) -> Result<Option<RoomType>> {
Ok(self
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.map(|s| {
serde_json::from_str::<RoomCreateEventContent>(s.content.get())
.map_err(|e| err!(Database(error!("Invalid room create event in database: {e}"))))
pub async fn get_room_type(&self, room_id: &RoomId) -> Result<RoomType> {
self.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.and_then(|content: RoomCreateEventContent| {
content
.room_type
.ok_or_else(|| err!(Request(NotFound("No type found in event content"))))
})
.transpose()?
.and_then(|e| e.room_type))
}
/// Gets the room's encryption algorithm if `m.room.encryption` state event
/// is found
pub fn get_room_encryption(&self, room_id: &RoomId) -> Result<Option<EventEncryptionAlgorithm>> {
self.room_state_get(room_id, &StateEventType::RoomEncryption, "")?
.map_or(Ok(None), |s| {
serde_json::from_str::<RoomEncryptionEventContent>(s.content.get())
.map(|content| Some(content.algorithm))
.map_err(|e| err!(Database(error!("Invalid room encryption event in database: {e}"))))
})
pub async fn get_room_encryption(&self, room_id: &RoomId) -> Result<EventEncryptionAlgorithm> {
self.room_state_get_content(room_id, &StateEventType::RoomEncryption, "")
.await
.map(|content: RoomEncryptionEventContent| content.algorithm)
}
}

View file

@ -1,43 +1,42 @@
use std::{
collections::{HashMap, HashSet},
collections::HashMap,
sync::{Arc, RwLock},
};
use conduit::{utils, Error, Result};
use database::Map;
use itertools::Itertools;
use conduit::{utils, utils::stream::TryIgnore, Error, Result};
use database::{Deserialized, Interfix, Map};
use futures::{Stream, StreamExt};
use ruma::{
events::{AnyStrippedStateEvent, AnySyncStateEvent},
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
OwnedRoomId, RoomId, UserId,
};
use crate::{appservice::RegistrationInfo, globals, users, Dep};
use crate::{globals, Dep};
type StrippedStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + 'a>;
type AnySyncStateEventIter<'a> = Box<dyn Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + 'a>;
type AppServiceInRoomCache = RwLock<HashMap<OwnedRoomId, HashMap<String, bool>>>;
type StrippedStateEventItem = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>);
type SyncStateEventItem = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>);
pub(super) struct Data {
pub(super) appservice_in_room_cache: AppServiceInRoomCache,
roomid_invitedcount: Arc<Map>,
roomid_inviteviaservers: Arc<Map>,
roomid_joinedcount: Arc<Map>,
roomserverids: Arc<Map>,
roomuserid_invitecount: Arc<Map>,
roomuserid_joined: Arc<Map>,
roomuserid_leftcount: Arc<Map>,
roomuseroncejoinedids: Arc<Map>,
serverroomids: Arc<Map>,
userroomid_invitestate: Arc<Map>,
userroomid_joined: Arc<Map>,
userroomid_leftstate: Arc<Map>,
pub(super) roomid_invitedcount: Arc<Map>,
pub(super) roomid_inviteviaservers: Arc<Map>,
pub(super) roomid_joinedcount: Arc<Map>,
pub(super) roomserverids: Arc<Map>,
pub(super) roomuserid_invitecount: Arc<Map>,
pub(super) roomuserid_joined: Arc<Map>,
pub(super) roomuserid_leftcount: Arc<Map>,
pub(super) roomuseroncejoinedids: Arc<Map>,
pub(super) serverroomids: Arc<Map>,
pub(super) userroomid_invitestate: Arc<Map>,
pub(super) userroomid_joined: Arc<Map>,
pub(super) userroomid_leftstate: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
users: Dep<users::Service>,
}
impl Data {
@ -59,19 +58,18 @@ impl Data {
userroomid_leftstate: db["userroomid_leftstate"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
users: args.depend::<users::Service>("users"),
},
}
}
pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_once_joined(&self, user_id: &UserId, room_id: &RoomId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.roomuseroncejoinedids.insert(&userroom_id, &[])
self.roomuseroncejoinedids.insert(&userroom_id, &[]);
}
pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) {
let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone();
@ -82,64 +80,17 @@ impl Data {
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_joined.insert(&userroom_id, &[])?;
self.roomuserid_joined.insert(&roomuser_id, &[])?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
self.userroomid_joined.insert(&userroom_id, &[]);
self.roomuserid_joined.insert(&roomuser_id, &[]);
self.userroomid_invitestate.remove(&userroom_id);
self.roomuserid_invitecount.remove(&roomuser_id);
self.userroomid_leftstate.remove(&userroom_id);
self.roomuserid_leftcount.remove(&roomuser_id);
self.roomid_inviteviaservers.remove(&roomid)?;
Ok(())
self.roomid_inviteviaservers.remove(&roomid);
}
pub(super) fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) -> Result<()> {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
)?;
self.roomuserid_invitecount
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
if let Some(servers) = invite_via {
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
}
Ok(())
}
pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) {
let roomid = room_id.as_bytes().to_vec();
let mut roomuser_id = roomid.clone();
@ -153,115 +104,20 @@ impl Data {
self.userroomid_leftstate.insert(
&userroom_id,
&serde_json::to_vec(&Vec::<Raw<AnySyncStateEvent>>::new()).unwrap(),
)?; // TODO
); // TODO
self.roomuserid_leftcount
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
self.userroomid_joined.remove(&userroom_id)?;
self.roomuserid_joined.remove(&roomuser_id)?;
self.userroomid_invitestate.remove(&userroom_id)?;
self.roomuserid_invitecount.remove(&roomuser_id)?;
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
self.userroomid_joined.remove(&userroom_id);
self.roomuserid_joined.remove(&roomuser_id);
self.userroomid_invitestate.remove(&userroom_id);
self.roomuserid_invitecount.remove(&roomuser_id);
self.roomid_inviteviaservers.remove(&roomid)?;
Ok(())
}
pub(super) fn update_joined_count(&self, room_id: &RoomId) -> Result<()> {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
for joined in self.room_members(room_id).filter_map(Result::ok) {
joined_servers.insert(joined.server_name().to_owned());
joinedcount = joinedcount.saturating_add(1);
}
for _invited in self.room_members_invited(room_id).filter_map(Result::ok) {
invitedcount = invitedcount.saturating_add(1);
}
self.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes())?;
self.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes())?;
for old_joined_server in self.room_servers(room_id).filter_map(Result::ok) {
if !joined_servers.remove(&old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.remove(&roomserver_id)?;
self.serverroomids.remove(&serverroom_id)?;
}
}
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.roomserverids.insert(&roomserver_id, &[])?;
self.serverroomids.insert(&serverroom_id, &[])?;
}
self.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
Ok(())
}
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub(super) fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
let maybe = self
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
Ok(b)
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
self.services.globals.server_name(),
)
.ok();
let in_room = bridge_user_id.map_or(false, |id| self.is_joined(&id, room_id).unwrap_or(false))
|| self
.room_members(room_id)
.any(|userid| userid.map_or(false, |userid| appservice.users.is_match(userid.as_str())));
self.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
Ok(in_room)
}
self.roomid_inviteviaservers.remove(&roomid);
}
/// Makes a user forget a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> {
pub(super) fn forget(&self, room_id: &RoomId, user_id: &UserId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
@ -270,397 +126,69 @@ impl Data {
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_leftstate.remove(&userroom_id)?;
self.roomuserid_leftcount.remove(&roomuser_id)?;
Ok(())
}
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_servers<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomserverids.scan_prefix(prefix).map(|(key, _)| {
ServerName::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Server name in roomserverids is invalid."))
}))
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
let mut key = server.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
self.serverroomids.get(&key).map(|o| o.is_some())
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn server_rooms<'a>(
&'a self, server: &ServerName,
) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a> {
let mut prefix = server.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.serverroomids.scan_prefix(prefix).map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("RoomId in serverroomids is invalid."))
}))
}
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_members<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + Send + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(self.roomuserid_joined.scan_prefix(prefix).map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_joined is invalid."))
}))
}
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub(super) fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new(
self.room_members(room_id)
.filter_map(Result::ok)
.filter(|user| self.services.globals.user_is_local(user)),
)
}
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn active_local_users_in_room<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = OwnedUserId> + 'a> {
Box::new(
self.local_users_in_room(room_id)
.filter(|user| !self.services.users.is_deactivated(user).unwrap_or(true)),
)
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_joinedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> {
self.roomid_invitedcount
.get(room_id.as_bytes())?
.map(|b| utils::u64_from_bytes(&b).map_err(|_| Error::bad_database("Invalid joinedcount in db.")))
.transpose()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_useroncejoined<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuseroncejoinedids
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in room_useroncejoined is invalid."))
}),
)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn room_members_invited<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedUserId>> + 'a> {
let mut prefix = room_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.roomuserid_invitecount
.scan_prefix(prefix)
.map(|(key, _)| {
UserId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("User ID in roomuserid_invited is invalid."))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_invitecount
.get(&key)?
.map_or(Ok(None), |bytes| {
Ok(Some(
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid invitecount in db."))?,
))
})
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
self.roomuserid_leftcount
.get(&key)?
.map(|bytes| utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid leftcount in db.")))
.transpose()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_joined(&self, user_id: &UserId) -> Box<dyn Iterator<Item = Result<OwnedRoomId>> + '_> {
Box::new(
self.userroomid_joined
.scan_prefix(user_id.as_bytes().to_vec())
.map(|(key, _)| {
RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_joined is invalid."))
}),
)
self.userroomid_leftstate.remove(&userroom_id);
self.roomuserid_leftcount.remove(&roomuser_id);
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_invited<'a>(&'a self, user_id: &UserId) -> StrippedStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
Box::new(
self.userroomid_invitestate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn invite_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
#[inline]
pub(super) fn rooms_invited<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = StrippedStateEventItem> + Send + 'a {
let prefix = (user_id, Interfix);
self.userroomid_invitestate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))?;
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap();
let room_id = utils::string_from_bytes(room_id).unwrap();
let room_id = RoomId::parse(room_id).unwrap();
let state = serde_json::from_slice(val)
.map_err(|_| Error::bad_database("Invalid state in userroomid_invitestate."))
.unwrap();
Ok(state)
(room_id, state)
})
.transpose()
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn left_state(
pub(super) async fn invite_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
let mut key = user_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(room_id.as_bytes());
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.userroomid_invitestate
.qry(&key)
.await
.deserialized_json()
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) async fn left_state(
&self, user_id: &UserId, room_id: &RoomId,
) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
let key = (user_id, room_id);
self.userroomid_leftstate
.get(&key)?
.map(|state| {
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok(state)
})
.transpose()
.qry(&key)
.await
.deserialized_json()
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn rooms_left<'a>(&'a self, user_id: &UserId) -> AnySyncStateEventIter<'a> {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
#[inline]
pub(super) fn rooms_left<'a>(&'a self, user_id: &'a UserId) -> impl Stream<Item = SyncStateEventItem> + Send + 'a {
let prefix = (user_id, Interfix);
self.userroomid_leftstate
.stream_raw_prefix(&prefix)
.ignore_err()
.map(|(key, val)| {
let room_id = key.rsplit(|&b| b == 0xFF).next().unwrap();
let room_id = utils::string_from_bytes(room_id).unwrap();
let room_id = RoomId::parse(room_id).unwrap();
let state = serde_json::from_slice(val)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))
.unwrap();
Box::new(
self.userroomid_leftstate
.scan_prefix(prefix)
.map(|(key, state)| {
let room_id = RoomId::parse(
utils::string_from_bytes(
key.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid unicode."))?,
)
.map_err(|_| Error::bad_database("Room ID in userroomid_invited is invalid."))?;
let state = serde_json::from_slice(&state)
.map_err(|_| Error::bad_database("Invalid state in userroomid_leftstate."))?;
Ok((room_id, state))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.roomuseroncejoinedids.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_joined.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_invitestate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
Ok(self.userroomid_leftstate.get(&userroom_id)?.is_some())
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn servers_invite_via<'a>(
&'a self, room_id: &RoomId,
) -> Box<dyn Iterator<Item = Result<OwnedServerName>> + 'a> {
let key = room_id.as_bytes().to_vec();
Box::new(
self.roomid_inviteviaservers
.scan_prefix(key)
.map(|(_, servers)| {
ServerName::parse(
utils::string_from_bytes(
servers
.rsplit(|&b| b == 0xFF)
.next()
.expect("rsplit always returns an element"),
)
.map_err(|_| {
Error::bad_database("Server name in roomid_inviteviaservers is invalid unicode.")
})?,
)
.map_err(|_| Error::bad_database("Server name in roomid_inviteviaservers is invalid."))
}),
)
}
#[tracing::instrument(skip(self), level = "debug")]
pub(super) fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) -> Result<()> {
let mut prev_servers = self
.servers_invite_via(room_id)
.filter_map(Result::ok)
.collect_vec();
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers)?;
Ok(())
(room_id, state)
})
}
}

View file

@ -1,9 +1,15 @@
mod data;
use std::sync::Arc;
use std::{collections::HashSet, sync::Arc};
use conduit::{err, error, warn, Error, Result};
use conduit::{
err,
utils::{stream::TryIgnore, ReadyExt},
warn, Result,
};
use data::Data;
use database::{Deserialized, Ignore, Interfix};
use futures::{Stream, StreamExt};
use itertools::Itertools;
use ruma::{
events::{
@ -18,7 +24,7 @@ use ruma::{
},
int,
serde::Raw,
OwnedRoomId, OwnedServerName, OwnedUserId, RoomId, ServerName, UserId,
OwnedRoomId, OwnedServerName, RoomId, ServerName, UserId,
};
use crate::{account_data, appservice::RegistrationInfo, globals, rooms, users, Dep};
@ -55,7 +61,7 @@ impl Service {
/// Update current membership data.
#[tracing::instrument(skip(self, last_state))]
#[allow(clippy::too_many_arguments)]
pub fn update_membership(
pub async fn update_membership(
&self, room_id: &RoomId, user_id: &UserId, membership_event: RoomMemberEventContent, sender: &UserId,
last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>, invite_via: Option<Vec<OwnedServerName>>,
update_joined_count: bool,
@ -68,7 +74,7 @@ impl Service {
// update
#[allow(clippy::collapsible_if)]
if !self.services.globals.user_is_local(user_id) {
if !self.services.users.exists(user_id)? {
if !self.services.users.exists(user_id).await {
self.services.users.create(user_id, None)?;
}
@ -100,17 +106,17 @@ impl Service {
match &membership {
MembershipState::Join => {
// Check if the user never joined this room
if !self.once_joined(user_id, room_id)? {
if !self.once_joined(user_id, room_id).await {
// Add the user ID to the join list then
self.db.mark_as_once_joined(user_id, room_id)?;
self.db.mark_as_once_joined(user_id, room_id);
// Check if the room has a predecessor
if let Some(predecessor) = self
if let Ok(Some(predecessor)) = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomCreate, "")?
.and_then(|create| serde_json::from_str(create.content.get()).ok())
.and_then(|content: RoomCreateEventContent| content.predecessor)
.room_state_get_content(room_id, &StateEventType::RoomCreate, "")
.await
.map(|content: RoomCreateEventContent| content.predecessor)
{
// Copy user settings from predecessor to the current room:
// - Push rules
@ -138,32 +144,33 @@ impl Service {
// .ok();
// Copy old tags to new room
if let Some(tag_event) = self
if let Ok(tag_event) = self
.services
.account_data
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)?
.map(|event| {
.get(Some(&predecessor.room_id), user_id, RoomAccountDataEventType::Tag)
.await
.and_then(|event| {
serde_json::from_str(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) {
self.services
.account_data
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event?)
.update(Some(room_id), user_id, RoomAccountDataEventType::Tag, &tag_event)
.await
.ok();
};
// Copy direct chat flag
if let Some(direct_event) = self
if let Ok(mut direct_event) = self
.services
.account_data
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())?
.map(|event| {
.get(None, user_id, GlobalAccountDataEventType::Direct.to_string().into())
.await
.and_then(|event| {
serde_json::from_str::<DirectEvent>(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
}) {
let mut direct_event = direct_event?;
let mut room_ids_updated = false;
for room_ids in direct_event.content.0.values_mut() {
if room_ids.iter().any(|r| r == &predecessor.room_id) {
room_ids.push(room_id.to_owned());
@ -172,18 +179,21 @@ impl Service {
}
if room_ids_updated {
self.services.account_data.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
)?;
self.services
.account_data
.update(
None,
user_id,
GlobalAccountDataEventType::Direct.to_string().into(),
&serde_json::to_value(&direct_event).expect("to json always works"),
)
.await?;
}
};
}
}
self.db.mark_as_joined(user_id, room_id)?;
self.db.mark_as_joined(user_id, room_id);
},
MembershipState::Invite => {
// We want to know if the sender is ignored by the receiver
@ -196,12 +206,12 @@ impl Service {
GlobalAccountDataEventType::IgnoredUserList
.to_string()
.into(),
)?
.map(|event| {
)
.await
.and_then(|event| {
serde_json::from_str::<IgnoredUserListEvent>(event.get())
.map_err(|e| err!(Database(warn!("Invalid account data event in db: {e:?}"))))
})
.transpose()?
.map_or(false, |ignored| {
ignored
.content
@ -214,194 +224,282 @@ impl Service {
return Ok(());
}
self.db
.mark_as_invited(user_id, room_id, last_state, invite_via)?;
self.mark_as_invited(user_id, room_id, last_state, invite_via)
.await;
},
MembershipState::Leave | MembershipState::Ban => {
self.db.mark_as_left(user_id, room_id)?;
self.db.mark_as_left(user_id, room_id);
},
_ => {},
}
if update_joined_count {
self.update_joined_count(room_id)?;
self.update_joined_count(room_id).await;
}
Ok(())
}
#[tracing::instrument(skip(self, room_id), level = "debug")]
pub fn update_joined_count(&self, room_id: &RoomId) -> Result<()> { self.db.update_joined_count(room_id) }
#[tracing::instrument(skip(self, room_id, appservice), level = "debug")]
pub fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> Result<bool> {
self.db.appservice_in_room(room_id, appservice)
pub async fn appservice_in_room(&self, room_id: &RoomId, appservice: &RegistrationInfo) -> bool {
let maybe = self
.db
.appservice_in_room_cache
.read()
.unwrap()
.get(room_id)
.and_then(|map| map.get(&appservice.registration.id))
.copied();
if let Some(b) = maybe {
b
} else {
let bridge_user_id = UserId::parse_with_server_name(
appservice.registration.sender_localpart.as_str(),
self.services.globals.server_name(),
)
.ok();
let in_room = if let Some(id) = &bridge_user_id {
self.is_joined(id, room_id).await
} else {
false
};
let in_room = in_room
|| self
.room_members(room_id)
.ready_any(|userid| appservice.users.is_match(userid.as_str()))
.await;
self.db
.appservice_in_room_cache
.write()
.unwrap()
.entry(room_id.to_owned())
.or_default()
.insert(appservice.registration.id.clone(), in_room);
in_room
}
}
/// Direct DB function to directly mark a user as left. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.mark_as_left(user_id, room_id)
}
pub fn mark_as_left(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_left(user_id, room_id); }
/// Direct DB function to directly mark a user as joined. It is not
/// recommended to use this directly. You most likely should use
/// `update_membership` instead
#[tracing::instrument(skip(self), level = "debug")]
pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.mark_as_joined(user_id, room_id)
}
pub fn mark_as_joined(&self, user_id: &UserId, room_id: &RoomId) { self.db.mark_as_joined(user_id, room_id); }
/// Makes a user forget a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) -> Result<()> { self.db.forget(room_id, user_id) }
pub fn forget(&self, room_id: &RoomId, user_id: &UserId) { self.db.forget(room_id, user_id); }
/// Returns an iterator of all servers participating in this room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_servers(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.room_servers(room_id)
pub fn room_servers<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomserverids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, server): (Ignore, &ServerName)| server)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_in_room(&self, server: &ServerName, room_id: &RoomId) -> Result<bool> {
self.db.server_in_room(server, room_id)
pub async fn server_in_room<'a>(&'a self, server: &'a ServerName, room_id: &'a RoomId) -> bool {
let key = (server, room_id);
self.db.serverroomids.qry(&key).await.is_ok()
}
/// Returns an iterator of all rooms a server participates in (as far as we
/// know).
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_rooms(&self, server: &ServerName) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
self.db.server_rooms(server)
pub fn server_rooms<'a>(&'a self, server: &'a ServerName) -> impl Stream<Item = &RoomId> + Send + 'a {
let prefix = (server, Interfix);
self.db
.serverroomids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
}
/// Returns true if server can see user by sharing at least one room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> Result<bool> {
Ok(self
.server_rooms(server)
.filter_map(Result::ok)
.any(|room_id: OwnedRoomId| self.is_joined(user_id, &room_id).unwrap_or(false)))
pub async fn server_sees_user(&self, server: &ServerName, user_id: &UserId) -> bool {
self.server_rooms(server)
.any(|room_id| self.is_joined(user_id, room_id))
.await
}
/// Returns true if user_a and user_b share at least one room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> Result<bool> {
pub async fn user_sees_user(&self, user_a: &UserId, user_b: &UserId) -> bool {
// Minimize number of point-queries by iterating user with least nr rooms
let (a, b) = if self.rooms_joined(user_a).count() < self.rooms_joined(user_b).count() {
let (a, b) = if self.rooms_joined(user_a).count().await < self.rooms_joined(user_b).count().await {
(user_a, user_b)
} else {
(user_b, user_a)
};
Ok(self
.rooms_joined(a)
.filter_map(Result::ok)
.any(|room_id| self.is_joined(b, &room_id).unwrap_or(false)))
self.rooms_joined(a)
.any(|room_id| self.is_joined(b, room_id))
.await
}
/// Returns an iterator over all joined members of a room.
/// Returns an iterator of all joined members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + Send + '_ {
self.db.room_members(room_id)
pub fn room_members<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_joined
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns the number of users which are currently in a room
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_joined_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_joined_count(room_id) }
pub async fn room_joined_count(&self, room_id: &RoomId) -> Result<u64> {
self.db.roomid_joinedcount.qry(room_id).await.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local users in the room, even if they're
/// deactivated/guests
pub fn local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a {
self.db.local_users_in_room(room_id)
pub fn local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
self.room_members(room_id)
.ready_filter(|user| self.services.globals.user_is_local(user))
}
#[tracing::instrument(skip(self), level = "debug")]
/// Returns an iterator of all our local joined users in a room who are
/// active (not deactivated, not guest)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &RoomId) -> impl Iterator<Item = OwnedUserId> + 'a {
self.db.active_local_users_in_room(room_id)
pub fn active_local_users_in_room<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
self.local_users_in_room(room_id)
.filter(|user| self.services.users.is_active(user))
}
/// Returns the number of users which are currently invited to a room
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_invited_count(&self, room_id: &RoomId) -> Result<Option<u64>> { self.db.room_invited_count(room_id) }
pub async fn room_invited_count(&self, room_id: &RoomId) -> Result<u64> {
self.db
.roomid_invitedcount
.qry(room_id)
.await
.deserialized()
}
/// Returns an iterator over all User IDs who ever joined a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_useroncejoined(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ {
self.db.room_useroncejoined(room_id)
pub fn room_useroncejoined<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuseroncejoinedids
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
/// Returns an iterator over all invited members of a room.
#[tracing::instrument(skip(self), level = "debug")]
pub fn room_members_invited(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedUserId>> + '_ {
self.db.room_members_invited(room_id)
pub fn room_members_invited<'a>(&'a self, room_id: &'a RoomId) -> impl Stream<Item = &UserId> + Send + 'a {
let prefix = (room_id, Interfix);
self.db
.roomuserid_invitecount
.keys_prefix(&prefix)
.ignore_err()
.map(|(_, user_id): (Ignore, &UserId)| user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.get_invite_count(room_id, user_id)
pub async fn get_invite_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db
.roomuserid_invitecount
.qry(&key)
.await
.deserialized()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<Option<u64>> {
self.db.get_left_count(room_id, user_id)
pub async fn get_left_count(&self, room_id: &RoomId, user_id: &UserId) -> Result<u64> {
let key = (room_id, user_id);
self.db.roomuserid_leftcount.qry(&key).await.deserialized()
}
/// Returns an iterator over all rooms this user joined.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_joined(&self, user_id: &UserId) -> impl Iterator<Item = Result<OwnedRoomId>> + '_ {
self.db.rooms_joined(user_id)
pub fn rooms_joined(&self, user_id: &UserId) -> impl Stream<Item = &RoomId> + Send {
self.db
.userroomid_joined
.keys_prefix(user_id)
.ignore_err()
.map(|(_, room_id): (Ignore, &RoomId)| room_id)
}
/// Returns an iterator over all rooms a user was invited to.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_invited(
&self, user_id: &UserId,
) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)>> + '_ {
pub fn rooms_invited<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = (OwnedRoomId, Vec<Raw<AnyStrippedStateEvent>>)> + Send + 'a {
self.db.rooms_invited(user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
self.db.invite_state(user_id, room_id)
pub async fn invite_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
self.db.invite_state(user_id, room_id).await
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Option<Vec<Raw<AnyStrippedStateEvent>>>> {
self.db.left_state(user_id, room_id)
pub async fn left_state(&self, user_id: &UserId, room_id: &RoomId) -> Result<Vec<Raw<AnyStrippedStateEvent>>> {
self.db.left_state(user_id, room_id).await
}
/// Returns an iterator over all rooms a user left.
#[tracing::instrument(skip(self), level = "debug")]
pub fn rooms_left(
&self, user_id: &UserId,
) -> impl Iterator<Item = Result<(OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)>> + '_ {
pub fn rooms_left<'a>(
&'a self, user_id: &'a UserId,
) -> impl Stream<Item = (OwnedRoomId, Vec<Raw<AnySyncStateEvent>>)> + Send + 'a {
self.db.rooms_left(user_id)
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.once_joined(user_id, room_id)
pub async fn once_joined(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.roomuseroncejoinedids.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_joined(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_joined(user_id, room_id) }
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> {
self.db.is_invited(user_id, room_id)
pub async fn is_joined<'a>(&'a self, user_id: &'a UserId, room_id: &'a RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_joined.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> Result<bool> { self.db.is_left(user_id, room_id) }
pub async fn is_invited(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_invitestate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via(&self, room_id: &RoomId) -> impl Iterator<Item = Result<OwnedServerName>> + '_ {
self.db.servers_invite_via(room_id)
pub async fn is_left(&self, user_id: &UserId, room_id: &RoomId) -> bool {
let key = (user_id, room_id);
self.db.userroomid_leftstate.qry(&key).await.is_ok()
}
#[tracing::instrument(skip(self), level = "debug")]
pub fn servers_invite_via<'a>(&'a self, room_id: &RoomId) -> impl Stream<Item = &ServerName> + Send + 'a {
self.db
.roomid_inviteviaservers
.stream_prefix(room_id)
.ignore_err()
.map(|(_, servers): (Ignore, Vec<&ServerName>)| &**(servers.last().expect("at least one servername")))
}
/// Gets up to three servers that are likely to be in the room in the
@ -409,37 +507,27 @@ impl Service {
///
/// See <https://spec.matrix.org/v1.10/appendices/#routing>
#[tracing::instrument(skip(self))]
pub fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
pub async fn servers_route_via(&self, room_id: &RoomId) -> Result<Vec<OwnedServerName>> {
let most_powerful_user_server = self
.services
.state_accessor
.room_state_get(room_id, &StateEventType::RoomPowerLevels, "")?
.map(|pdu| {
serde_json::from_str(pdu.content.get()).map(|conent: RoomPowerLevelsEventContent| {
conent
.users
.iter()
.max_by_key(|(_, power)| *power)
.and_then(|x| {
if x.1 >= &int!(50) {
Some(x)
} else {
None
}
})
.map(|(user, _power)| user.server_name().to_owned())
})
.room_state_get_content(room_id, &StateEventType::RoomPowerLevels, "")
.await
.map(|content: RoomPowerLevelsEventContent| {
content
.users
.iter()
.max_by_key(|(_, power)| *power)
.and_then(|x| (x.1 >= &int!(50)).then_some(x))
.map(|(user, _power)| user.server_name().to_owned())
})
.transpose()
.map_err(|e| {
error!("Invalid power levels event content in database: {e}");
Error::bad_database("Invalid power levels event content in database")
})?
.flatten();
.map_err(|e| err!(Database(error!(?e, "Invalid power levels event content in database."))))?;
let mut servers: Vec<OwnedServerName> = self
.room_members(room_id)
.filter_map(Result::ok)
.collect::<Vec<_>>()
.await
.iter()
.counts_by(|user| user.server_name().to_owned())
.iter()
.sorted_by_key(|(_, users)| *users)
@ -468,4 +556,139 @@ impl Service {
.expect("locked")
.clear();
}
pub async fn update_joined_count(&self, room_id: &RoomId) {
let mut joinedcount = 0_u64;
let mut invitedcount = 0_u64;
let mut joined_servers = HashSet::new();
self.room_members(room_id)
.ready_for_each(|joined| {
joined_servers.insert(joined.server_name().to_owned());
joinedcount = joinedcount.saturating_add(1);
})
.await;
invitedcount = invitedcount.saturating_add(
self.room_members_invited(room_id)
.count()
.await
.try_into()
.unwrap_or(0),
);
self.db
.roomid_joinedcount
.insert(room_id.as_bytes(), &joinedcount.to_be_bytes());
self.db
.roomid_invitedcount
.insert(room_id.as_bytes(), &invitedcount.to_be_bytes());
self.room_servers(room_id)
.ready_for_each(|old_joined_server| {
if !joined_servers.remove(old_joined_server) {
// Server not in room anymore
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(old_joined_server.as_bytes());
let mut serverroom_id = old_joined_server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.db.roomserverids.remove(&roomserver_id);
self.db.serverroomids.remove(&serverroom_id);
}
})
.await;
// Now only new servers are in joined_servers anymore
for server in joined_servers {
let mut roomserver_id = room_id.as_bytes().to_vec();
roomserver_id.push(0xFF);
roomserver_id.extend_from_slice(server.as_bytes());
let mut serverroom_id = server.as_bytes().to_vec();
serverroom_id.push(0xFF);
serverroom_id.extend_from_slice(room_id.as_bytes());
self.db.roomserverids.insert(&roomserver_id, &[]);
self.db.serverroomids.insert(&serverroom_id, &[]);
}
self.db
.appservice_in_room_cache
.write()
.unwrap()
.remove(room_id);
}
pub async fn mark_as_invited(
&self, user_id: &UserId, room_id: &RoomId, last_state: Option<Vec<Raw<AnyStrippedStateEvent>>>,
invite_via: Option<Vec<OwnedServerName>>,
) {
let mut roomuser_id = room_id.as_bytes().to_vec();
roomuser_id.push(0xFF);
roomuser_id.extend_from_slice(user_id.as_bytes());
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
self.db.userroomid_invitestate.insert(
&userroom_id,
&serde_json::to_vec(&last_state.unwrap_or_default()).expect("state to bytes always works"),
);
self.db
.roomuserid_invitecount
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
self.db.userroomid_joined.remove(&userroom_id);
self.db.roomuserid_joined.remove(&roomuser_id);
self.db.userroomid_leftstate.remove(&userroom_id);
self.db.roomuserid_leftcount.remove(&roomuser_id);
if let Some(servers) = invite_via {
let mut prev_servers = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
#[allow(clippy::redundant_clone)] // this is a necessary clone?
prev_servers.append(servers.clone().as_mut());
let servers = prev_servers.iter().rev().unique().rev().collect_vec();
let servers = servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.db
.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers);
}
}
#[tracing::instrument(skip(self), level = "debug")]
pub async fn add_servers_invite_via(&self, room_id: &RoomId, servers: &[OwnedServerName]) {
let mut prev_servers = self
.servers_invite_via(room_id)
.map(ToOwned::to_owned)
.collect::<Vec<_>>()
.await;
prev_servers.extend(servers.to_owned());
prev_servers.sort_unstable();
prev_servers.dedup();
let servers = prev_servers
.iter()
.map(|server| server.as_bytes())
.collect_vec()
.join(&[0xFF][..]);
self.db
.roomid_inviteviaservers
.insert(room_id.as_bytes(), &servers);
}
}

View file

@ -1,6 +1,6 @@
use std::{collections::HashSet, mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, Result};
use conduit::{err, expected, utils, Result};
use database::{Database, Map};
use super::CompressedStateEvent;
@ -22,11 +22,13 @@ impl Data {
}
}
pub(super) fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
pub(super) async fn get_statediff(&self, shortstatehash: u64) -> Result<StateDiff> {
let value = self
.shortstatehash_statediff
.get(&shortstatehash.to_be_bytes())?
.ok_or_else(|| Error::bad_database("State hash does not exist"))?;
.qry(&shortstatehash)
.await
.map_err(|e| err!(Database("Failed to find StateDiff from short {shortstatehash:?}: {e}")))?;
let parent = utils::u64_from_bytes(&value[0..size_of::<u64>()]).expect("bytes have right length");
let parent = if parent != 0 {
Some(parent)
@ -40,10 +42,10 @@ impl Data {
let stride = size_of::<u64>();
let mut i = stride;
while let Some(v) = value.get(i..checked!(i + 2 * stride)?) {
while let Some(v) = value.get(i..expected!(i + 2 * stride)) {
if add_mode && v.starts_with(&0_u64.to_be_bytes()) {
add_mode = false;
i = checked!(i + stride)?;
i = expected!(i + stride);
continue;
}
if add_mode {
@ -51,7 +53,7 @@ impl Data {
} else {
removed.insert(v.try_into().expect("we checked the size above"));
}
i = checked!(i + 2 * stride)?;
i = expected!(i + 2 * stride);
}
Ok(StateDiff {
@ -61,7 +63,7 @@ impl Data {
})
}
pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) -> Result<()> {
pub(super) fn save_statediff(&self, shortstatehash: u64, diff: &StateDiff) {
let mut value = diff.parent.unwrap_or(0).to_be_bytes().to_vec();
for new in diff.added.iter() {
value.extend_from_slice(&new[..]);
@ -75,6 +77,6 @@ impl Data {
}
self.shortstatehash_statediff
.insert(&shortstatehash.to_be_bytes(), &value)
.insert(&shortstatehash.to_be_bytes(), &value);
}
}

View file

@ -27,14 +27,12 @@ type StateInfoLruCache = Mutex<
>,
>;
type ShortStateInfoResult = Result<
Vec<(
u64, // sstatehash
Arc<HashSet<CompressedStateEvent>>, // full state
Arc<HashSet<CompressedStateEvent>>, // added
Arc<HashSet<CompressedStateEvent>>, // removed
)>,
>;
type ShortStateInfoResult = Vec<(
u64, // sstatehash
Arc<HashSet<CompressedStateEvent>>, // full state
Arc<HashSet<CompressedStateEvent>>, // added
Arc<HashSet<CompressedStateEvent>>, // removed
)>;
type ParentStatesVec = Vec<(
u64, // sstatehash
@ -43,7 +41,7 @@ type ParentStatesVec = Vec<(
Arc<HashSet<CompressedStateEvent>>, // removed
)>;
type HashSetCompressStateEvent = Result<(u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>)>;
type HashSetCompressStateEvent = (u64, Arc<HashSet<CompressedStateEvent>>, Arc<HashSet<CompressedStateEvent>>);
pub type CompressedStateEvent = [u8; 2 * size_of::<u64>()];
pub struct Service {
@ -86,12 +84,11 @@ impl crate::Service for Service {
impl Service {
/// Returns a stack with info on shortstatehash, full state, added diff and
/// removed diff for the selected shortstatehash and each parent layer.
#[tracing::instrument(skip(self), level = "debug")]
pub fn load_shortstatehash_info(&self, shortstatehash: u64) -> ShortStateInfoResult {
pub async fn load_shortstatehash_info(&self, shortstatehash: u64) -> Result<ShortStateInfoResult> {
if let Some(r) = self
.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.get_mut(&shortstatehash)
{
return Ok(r.clone());
@ -101,11 +98,11 @@ impl Service {
parent,
added,
removed,
} = self.db.get_statediff(shortstatehash)?;
} = self.db.get_statediff(shortstatehash).await?;
if let Some(parent) = parent {
let mut response = self.load_shortstatehash_info(parent)?;
let mut state = (*response.last().unwrap().1).clone();
let mut response = Box::pin(self.load_shortstatehash_info(parent)).await?;
let mut state = (*response.last().expect("at least one response").1).clone();
state.extend(added.iter().copied());
let removed = (*removed).clone();
for r in &removed {
@ -116,7 +113,7 @@ impl Service {
self.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
@ -124,33 +121,42 @@ impl Service {
let response = vec![(shortstatehash, added.clone(), added, removed)];
self.stateinfo_cache
.lock()
.unwrap()
.expect("locked")
.insert(shortstatehash, response.clone());
Ok(response)
}
}
pub fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> Result<CompressedStateEvent> {
pub async fn compress_state_event(&self, shortstatekey: u64, event_id: &EventId) -> CompressedStateEvent {
let mut v = shortstatekey.to_be_bytes().to_vec();
v.extend_from_slice(
&self
.services
.short
.get_or_create_shorteventid(event_id)?
.get_or_create_shorteventid(event_id)
.await
.to_be_bytes(),
);
Ok(v.try_into().expect("we checked the size above"))
v.try_into().expect("we checked the size above")
}
/// Returns shortstatekey, event id
#[inline]
pub fn parse_compressed_state_event(&self, compressed_event: &CompressedStateEvent) -> Result<(u64, Arc<EventId>)> {
Ok((
utils::u64_from_bytes(&compressed_event[0..size_of::<u64>()]).expect("bytes have right length"),
self.services.short.get_eventid_from_short(
utils::u64_from_bytes(&compressed_event[size_of::<u64>()..]).expect("bytes have right length"),
)?,
))
pub async fn parse_compressed_state_event(
&self, compressed_event: &CompressedStateEvent,
) -> Result<(u64, Arc<EventId>)> {
use utils::u64_from_u8;
let shortstatekey = u64_from_u8(&compressed_event[0..size_of::<u64>()]);
let event_id = self
.services
.short
.get_eventid_from_short(u64_from_u8(&compressed_event[size_of::<u64>()..]))
.await?;
Ok((shortstatekey, event_id))
}
/// Creates a new shortstatehash that often is just a diff to an already
@ -227,7 +233,7 @@ impl Service {
added: statediffnew,
removed: statediffremoved,
},
)?;
);
return Ok(());
};
@ -280,7 +286,7 @@ impl Service {
added: statediffnew,
removed: statediffremoved,
},
)?;
);
}
Ok(())
@ -288,10 +294,15 @@ impl Service {
/// Returns the new shortstatehash, and the state diff from the previous
/// room state
pub fn save_state(
pub async fn save_state(
&self, room_id: &RoomId, new_state_ids_compressed: Arc<HashSet<CompressedStateEvent>>,
) -> HashSetCompressStateEvent {
let previous_shortstatehash = self.services.state.get_room_shortstatehash(room_id)?;
) -> Result<HashSetCompressStateEvent> {
let previous_shortstatehash = self
.services
.state
.get_room_shortstatehash(room_id)
.await
.ok();
let state_hash = utils::calculate_hash(
&new_state_ids_compressed
@ -303,14 +314,18 @@ impl Service {
let (new_shortstatehash, already_existed) = self
.services
.short
.get_or_create_shortstatehash(&state_hash)?;
.get_or_create_shortstatehash(&state_hash)
.await;
if Some(new_shortstatehash) == previous_shortstatehash {
return Ok((new_shortstatehash, Arc::new(HashSet::new()), Arc::new(HashSet::new())));
}
let states_parents =
previous_shortstatehash.map_or_else(|| Ok(Vec::new()), |p| self.load_shortstatehash_info(p))?;
let states_parents = if let Some(p) = previous_shortstatehash {
self.load_shortstatehash_info(p).await.unwrap_or_default()
} else {
ShortStateInfoResult::new()
};
let (statediffnew, statediffremoved) = if let Some(parent_stateinfo) = states_parents.last() {
let statediffnew: HashSet<_> = new_state_ids_compressed

View file

@ -1,13 +1,18 @@
use std::{mem::size_of, sync::Arc};
use conduit::{checked, utils, Error, PduEvent, Result};
use database::Map;
use conduit::{
checked,
result::LogErr,
utils,
utils::{stream::TryIgnore, ReadyExt},
PduEvent, Result,
};
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{api::client::threads::get_threads::v1::IncludeThreads, OwnedUserId, RoomId, UserId};
use crate::{rooms, Dep};
type PduEventIterResult<'a> = Result<Box<dyn Iterator<Item = Result<(u64, PduEvent)>> + 'a>>;
pub(super) struct Data {
threadid_userids: Arc<Map>,
services: Services,
@ -30,38 +35,37 @@ impl Data {
}
}
pub(super) fn threads_until<'a>(
pub(super) async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, _include: &'a IncludeThreads,
) -> PduEventIterResult<'a> {
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists")
.get_shortroomid(room_id)
.await?
.to_be_bytes()
.to_vec();
let mut current = prefix.clone();
current.extend_from_slice(&(checked!(until - 1)?).to_be_bytes());
Ok(Box::new(
self.threadid_userids
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pduid, _users)| {
let count = utils::u64_from_bytes(&pduid[(size_of::<u64>())..])
.map_err(|_| Error::bad_database("Invalid pduid in threadid_userids."))?;
let mut pdu = self
.services
.timeline
.get_pdu_from_id(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid reference in threadid_userids"))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
Ok((count, pdu))
}),
))
let stream = self
.threadid_userids
.rev_raw_keys_from(&current)
.ignore_err()
.ready_take_while(move |key| key.starts_with(&prefix))
.map(|pduid| (utils::u64_from_u8(&pduid[(size_of::<u64>())..]), pduid))
.filter_map(move |(count, pduid)| async move {
let mut pdu = self.services.timeline.get_pdu_from_id(pduid).await.ok()?;
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
Some((count, pdu))
});
Ok(stream)
}
pub(super) fn update_participants(&self, root_id: &[u8], participants: &[OwnedUserId]) -> Result<()> {
@ -71,28 +75,12 @@ impl Data {
.collect::<Vec<_>>()
.join(&[0xFF][..]);
self.threadid_userids.insert(root_id, &users)?;
self.threadid_userids.insert(root_id, &users);
Ok(())
}
pub(super) fn get_participants(&self, root_id: &[u8]) -> Result<Option<Vec<OwnedUserId>>> {
if let Some(users) = self.threadid_userids.get(root_id)? {
Ok(Some(
users
.split(|b| *b == 0xFF)
.map(|bytes| {
UserId::parse(
utils::string_from_bytes(bytes)
.map_err(|_| Error::bad_database("Invalid UserId bytes in threadid_userids."))?,
)
.map_err(|_| Error::bad_database("Invalid UserId in threadid_userids."))
})
.filter_map(Result::ok)
.collect(),
))
} else {
Ok(None)
}
pub(super) async fn get_participants(&self, root_id: &[u8]) -> Result<Vec<OwnedUserId>> {
self.threadid_userids.qry(root_id).await.deserialized()
}
}

View file

@ -2,12 +2,12 @@ mod data;
use std::{collections::BTreeMap, sync::Arc};
use conduit::{Error, PduEvent, Result};
use conduit::{err, PduEvent, Result};
use data::Data;
use futures::Stream;
use ruma::{
api::client::{error::ErrorKind, threads::get_threads::v1::IncludeThreads},
events::relation::BundledThread,
uint, CanonicalJsonValue, EventId, RoomId, UserId,
api::client::threads::get_threads::v1::IncludeThreads, events::relation::BundledThread, uint, CanonicalJsonValue,
EventId, RoomId, UserId,
};
use serde_json::json;
@ -36,30 +36,35 @@ impl crate::Service for Service {
}
impl Service {
pub fn threads_until<'a>(
pub async fn threads_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: u64, include: &'a IncludeThreads,
) -> Result<impl Iterator<Item = Result<(u64, PduEvent)>> + 'a> {
self.db.threads_until(user_id, room_id, until, include)
) -> Result<impl Stream<Item = (u64, PduEvent)> + Send + 'a> {
self.db
.threads_until(user_id, room_id, until, include)
.await
}
pub fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
pub async fn add_to_thread(&self, root_event_id: &EventId, pdu: &PduEvent) -> Result<()> {
let root_id = self
.services
.timeline
.get_pdu_id(root_event_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Invalid event id in thread message"))?;
.get_pdu_id(root_event_id)
.await
.map_err(|e| err!(Request(InvalidParam("Invalid event_id in thread message: {e:?}"))))?;
let root_pdu = self
.services
.timeline
.get_pdu_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
.get_pdu_from_id(&root_id)
.await
.map_err(|e| err!(Request(InvalidParam("Thread root not found: {e:?}"))))?;
let mut root_pdu_json = self
.services
.timeline
.get_pdu_json_from_id(&root_id)?
.ok_or_else(|| Error::BadRequest(ErrorKind::InvalidParam, "Thread root pdu not found"))?;
.get_pdu_json_from_id(&root_id)
.await
.map_err(|e| err!(Request(InvalidParam("Thread root pdu not found: {e:?}"))))?;
if let CanonicalJsonValue::Object(unsigned) = root_pdu_json
.entry("unsigned".to_owned())
@ -103,11 +108,12 @@ impl Service {
self.services
.timeline
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)?;
.replace_pdu(&root_id, &root_pdu_json, &root_pdu)
.await?;
}
let mut users = Vec::new();
if let Some(userids) = self.db.get_participants(&root_id)? {
if let Ok(userids) = self.db.get_participants(&root_id).await {
users.extend_from_slice(&userids);
} else {
users.push(root_pdu.sender);

View file

@ -1,12 +1,20 @@
use std::{
collections::{hash_map, HashMap},
mem::size_of,
sync::{Arc, Mutex},
sync::Arc,
};
use conduit::{checked, error, utils, Error, PduCount, PduEvent, Result};
use database::{Database, Map};
use ruma::{api::client::error::ErrorKind, CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use conduit::{
err, expected,
result::{LogErr, NotFound},
utils,
utils::{stream::TryIgnore, u64_from_u8, ReadyExt},
Err, PduCount, PduEvent, Result,
};
use database::{Database, Deserialized, KeyVal, Map};
use futures::{FutureExt, Stream, StreamExt};
use ruma::{CanonicalJsonObject, EventId, OwnedRoomId, OwnedUserId, RoomId, UserId};
use tokio::sync::Mutex;
use crate::{rooms, Dep};
@ -25,8 +33,7 @@ struct Services {
short: Dep<rooms::short::Service>,
}
type PdusIterItem = Result<(PduCount, PduEvent)>;
type PdusIterator<'a> = Box<dyn Iterator<Item = PdusIterItem> + 'a>;
pub type PdusIterItem = (PduCount, PduEvent);
type LastTimelineCountCache = Mutex<HashMap<OwnedRoomId, PduCount>>;
impl Data {
@ -46,23 +53,20 @@ impl Data {
}
}
pub(super) fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
pub(super) async fn last_timeline_count(&self, sender_user: &UserId, room_id: &RoomId) -> Result<PduCount> {
match self
.lasttimelinecount_cache
.lock()
.expect("locked")
.await
.entry(room_id.to_owned())
{
hash_map::Entry::Vacant(v) => {
if let Some(last_count) = self
.pdus_until(sender_user, room_id, PduCount::max())?
.find_map(|r| {
// Filter out buggy events
if r.is_err() {
error!("Bad pdu in pdus_since: {:?}", r);
}
r.ok()
}) {
.pdus_until(sender_user, room_id, PduCount::max())
.await?
.next()
.await
{
Ok(*v.insert(last_count.0))
} else {
Ok(PduCount::Normal(0))
@ -73,232 +77,215 @@ impl Data {
}
/// Returns the `count` of this pdu's id.
pub(super) fn get_pdu_count(&self, event_id: &EventId) -> Result<Option<PduCount>> {
pub(super) async fn get_pdu_count(&self, event_id: &EventId) -> Result<PduCount> {
self.eventid_pduid
.get(event_id.as_bytes())?
.qry(event_id)
.await
.map(|pdu_id| pdu_count(&pdu_id))
.transpose()
}
/// Returns the json of a pdu.
pub(super) fn get_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.get_non_outlier_pdu_json(event_id)?.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)
pub(super) async fn get_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
if let Ok(pdu) = self.get_non_outlier_pdu_json(event_id).await {
return Ok(pdu);
}
self.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
}
/// Returns the json of a pdu.
pub(super) fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<Option<CanonicalJsonObject>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
pub(super) async fn get_non_outlier_pdu_json(&self, event_id: &EventId) -> Result<CanonicalJsonObject> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await.deserialized_json()
}
/// Returns the pdu's id.
#[inline]
pub(super) fn get_pdu_id(&self, event_id: &EventId) -> Result<Option<database::Handle<'_>>> {
self.eventid_pduid.get(event_id.as_bytes())
pub(super) async fn get_pdu_id(&self, event_id: &EventId) -> Result<database::Handle<'_>> {
self.eventid_pduid.qry(event_id).await
}
/// Returns the pdu directly from `eventid_pduid` only.
pub(super) fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<Option<PduEvent>> {
self.eventid_pduid
.get(event_id.as_bytes())?
.map(|pduid| {
self.pduid_pdu
.get(&pduid)?
.ok_or_else(|| Error::bad_database("Invalid pduid in eventid_pduid."))
})
.transpose()?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
pub(super) async fn get_non_outlier_pdu(&self, event_id: &EventId) -> Result<PduEvent> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await.deserialized_json()
}
/// Like get_non_outlier_pdu(), but without the expense of fetching and
/// parsing the PduEvent
pub(super) async fn non_outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> {
let pduid = self.get_pdu_id(event_id).await?;
self.pduid_pdu.qry(&pduid).await?;
Ok(())
}
/// Returns the pdu.
///
/// Checks the `eventid_outlierpdu` Tree if not found in the timeline.
pub(super) fn get_pdu(&self, event_id: &EventId) -> Result<Option<Arc<PduEvent>>> {
if let Some(pdu) = self
.get_non_outlier_pdu(event_id)?
.map_or_else(
|| {
self.eventid_outlierpdu
.get(event_id.as_bytes())?
.map(|pdu| serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db.")))
.transpose()
},
|x| Ok(Some(x)),
)?
.map(Arc::new)
{
Ok(Some(pdu))
} else {
Ok(None)
pub(super) async fn get_pdu(&self, event_id: &EventId) -> Result<Arc<PduEvent>> {
if let Ok(pdu) = self.get_non_outlier_pdu(event_id).await {
return Ok(Arc::new(pdu));
}
self.eventid_outlierpdu
.qry(event_id)
.await
.deserialized_json()
.map(Arc::new)
}
/// Like get_non_outlier_pdu(), but without the expense of fetching and
/// parsing the PduEvent
pub(super) async fn outlier_pdu_exists(&self, event_id: &EventId) -> Result<()> {
self.eventid_outlierpdu.qry(event_id).await?;
Ok(())
}
/// Like get_pdu(), but without the expense of fetching and parsing the data
pub(super) async fn pdu_exists(&self, event_id: &EventId) -> bool {
let non_outlier = self.non_outlier_pdu_exists(event_id).map(|res| res.is_ok());
let outlier = self.outlier_pdu_exists(event_id).map(|res| res.is_ok());
//TODO: parallelize
non_outlier.await || outlier.await
}
/// Returns the pdu.
///
/// This does __NOT__ check the outliers `Tree`.
pub(super) fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<Option<PduEvent>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
pub(super) async fn get_pdu_from_id(&self, pdu_id: &[u8]) -> Result<PduEvent> {
self.pduid_pdu.qry(pdu_id).await.deserialized_json()
}
/// Returns the pdu as a `BTreeMap<String, CanonicalJsonValue>`.
pub(super) fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<Option<CanonicalJsonObject>> {
self.pduid_pdu.get(pdu_id)?.map_or(Ok(None), |pdu| {
Ok(Some(
serde_json::from_slice(&pdu).map_err(|_| Error::bad_database("Invalid PDU in db."))?,
))
})
pub(super) async fn get_pdu_json_from_id(&self, pdu_id: &[u8]) -> Result<CanonicalJsonObject> {
self.pduid_pdu.qry(pdu_id).await.deserialized_json()
}
pub(super) fn append_pdu(
&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64,
) -> Result<()> {
pub(super) async fn append_pdu(&self, pdu_id: &[u8], pdu: &PduEvent, json: &CanonicalJsonObject, count: u64) {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
);
self.lasttimelinecount_cache
.lock()
.expect("locked")
.await
.insert(pdu.room_id.clone(), PduCount::Normal(count));
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes())?;
Ok(())
self.eventid_pduid.insert(pdu.event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(pdu.event_id.as_bytes());
}
pub(super) fn prepend_backfill_pdu(
&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject,
) -> Result<()> {
pub(super) fn prepend_backfill_pdu(&self, pdu_id: &[u8], event_id: &EventId, json: &CanonicalJsonObject) {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(json).expect("CanonicalJsonObject is always a valid"),
)?;
);
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id)?;
self.eventid_outlierpdu.remove(event_id.as_bytes())?;
Ok(())
self.eventid_pduid.insert(event_id.as_bytes(), pdu_id);
self.eventid_outlierpdu.remove(event_id.as_bytes());
}
/// Removes a pdu and creates a new one with the same id.
pub(super) fn replace_pdu(&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent) -> Result<()> {
if self.pduid_pdu.get(pdu_id)?.is_some() {
self.pduid_pdu.insert(
pdu_id,
&serde_json::to_vec(pdu_json).expect("CanonicalJsonObject is always a valid"),
)?;
} else {
return Err(Error::BadRequest(ErrorKind::NotFound, "PDU does not exist."));
pub(super) async fn replace_pdu(
&self, pdu_id: &[u8], pdu_json: &CanonicalJsonObject, _pdu: &PduEvent,
) -> Result<()> {
if self.pduid_pdu.qry(pdu_id).await.is_not_found() {
return Err!(Request(NotFound("PDU does not exist.")));
}
let pdu = serde_json::to_vec(pdu_json)?;
self.pduid_pdu.insert(pdu_id, &pdu);
Ok(())
}
/// Returns an iterator over all events and their tokens in a room that
/// happened before the event with id `until` in reverse-chronological
/// order.
pub(super) fn pdus_until(&self, user_id: &UserId, room_id: &RoomId, until: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = self.count_to_id(room_id, until, 1, true)?;
pub(super) async fn pdus_until<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, until: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, until, 1, true).await?;
let stream = self
.pduid_pdu
.rev_raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
let user_id = user_id.to_owned();
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, true)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
Ok(stream)
}
pub(super) fn pdus_after(&self, user_id: &UserId, room_id: &RoomId, from: PduCount) -> Result<PdusIterator<'_>> {
let (prefix, current) = self.count_to_id(room_id, from, 1, false)?;
pub(super) async fn pdus_after<'a>(
&'a self, user_id: &'a UserId, room_id: &'a RoomId, from: PduCount,
) -> Result<impl Stream<Item = PdusIterItem> + Send + 'a> {
let (prefix, current) = self.count_to_id(room_id, from, 1, false).await?;
let stream = self
.pduid_pdu
.raw_stream_from(&current)
.ignore_err()
.ready_take_while(move |(key, _)| key.starts_with(&prefix))
.map(move |item| Self::each_pdu(item, user_id));
let user_id = user_id.to_owned();
Ok(stream)
}
Ok(Box::new(
self.pduid_pdu
.iter_from(&current, false)
.take_while(move |(k, _)| k.starts_with(&prefix))
.map(move |(pdu_id, v)| {
let mut pdu = serde_json::from_slice::<PduEvent>(&v)
.map_err(|_| Error::bad_database("PDU in db is invalid."))?;
if pdu.sender != user_id {
pdu.remove_transaction_id()?;
}
pdu.add_age()?;
let count = pdu_count(&pdu_id)?;
Ok((count, pdu))
}),
))
fn each_pdu((pdu_id, pdu): KeyVal<'_>, user_id: &UserId) -> PdusIterItem {
let mut pdu =
serde_json::from_slice::<PduEvent>(pdu).expect("PduEvent in pduid_pdu database column is invalid JSON");
if pdu.sender != user_id {
pdu.remove_transaction_id().log_err().ok();
}
pdu.add_age().log_err().ok();
let count = pdu_count(pdu_id);
(count, pdu)
}
pub(super) fn increment_notification_counts(
&self, room_id: &RoomId, notifies: Vec<OwnedUserId>, highlights: Vec<OwnedUserId>,
) -> Result<()> {
let mut notifies_batch = Vec::new();
let mut highlights_batch = Vec::new();
) {
let _cork = self.db.cork();
for user in notifies {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
notifies_batch.push(userroom_id);
increment(&self.userroomid_notificationcount, &userroom_id);
}
for user in highlights {
let mut userroom_id = user.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
highlights_batch.push(userroom_id);
increment(&self.userroomid_highlightcount, &userroom_id);
}
self.userroomid_notificationcount
.increment_batch(notifies_batch.iter().map(Vec::as_slice))?;
self.userroomid_highlightcount
.increment_batch(highlights_batch.iter().map(Vec::as_slice))?;
Ok(())
}
pub(super) fn count_to_id(
pub(super) async fn count_to_id(
&self, room_id: &RoomId, count: PduCount, offset: u64, subtract: bool,
) -> Result<(Vec<u8>, Vec<u8>)> {
let prefix = self
.services
.short
.get_shortroomid(room_id)?
.ok_or_else(|| Error::bad_database("Looked for bad shortroomid in timeline"))?
.get_shortroomid(room_id)
.await
.map_err(|e| err!(Request(NotFound("Room {room_id:?} not found: {e:?}"))))?
.to_be_bytes()
.to_vec();
let mut pdu_id = prefix.clone();
// +1 so we don't send the base event
let count_raw = match count {
@ -326,17 +313,23 @@ impl Data {
}
/// Returns the `count` of this pdu's id.
pub(super) fn pdu_count(pdu_id: &[u8]) -> Result<PduCount> {
let stride = size_of::<u64>();
let pdu_id_len = pdu_id.len();
let last_u64 = utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - stride)?..])
.map_err(|_| Error::bad_database("PDU has invalid count bytes."))?;
let second_last_u64 =
utils::u64_from_bytes(&pdu_id[checked!(pdu_id_len - 2 * stride)?..checked!(pdu_id_len - stride)?]);
pub(super) fn pdu_count(pdu_id: &[u8]) -> PduCount {
const STRIDE: usize = size_of::<u64>();
if matches!(second_last_u64, Ok(0)) {
Ok(PduCount::Backfilled(u64::MAX.saturating_sub(last_u64)))
let pdu_id_len = pdu_id.len();
let last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - STRIDE)..]);
let second_last_u64 = u64_from_u8(&pdu_id[expected!(pdu_id_len - 2 * STRIDE)..expected!(pdu_id_len - STRIDE)]);
if second_last_u64 == 0 {
PduCount::Backfilled(u64::MAX.saturating_sub(last_u64))
} else {
Ok(PduCount::Normal(last_u64))
PduCount::Normal(last_u64)
}
}
//TODO: this is an ABA
fn increment(db: &Arc<Map>, key: &[u8]) {
let old = db.get(key);
let new = utils::increment(old.ok().as_deref());
db.insert(key, &new);
}

File diff suppressed because it is too large Load diff

View file

@ -46,7 +46,7 @@ impl Service {
/// Sets a user as typing until the timeout timestamp is reached or
/// roomtyping_remove is called.
pub async fn typing_add(&self, user_id: &UserId, room_id: &RoomId, timeout: u64) -> Result<()> {
debug_info!("typing started {:?} in {:?} timeout:{:?}", user_id, room_id, timeout);
debug_info!("typing started {user_id:?} in {room_id:?} timeout:{timeout:?}");
// update clients
self.typing
.write()
@ -54,17 +54,19 @@ impl Service {
.entry(room_id.to_owned())
.or_default()
.insert(user_id.to_owned(), timeout);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, true)?;
self.federation_send(room_id, user_id, true).await?;
}
Ok(())
@ -72,7 +74,7 @@ impl Service {
/// Removes a user from typing before the timeout is reached.
pub async fn typing_remove(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
debug_info!("typing stopped {:?} in {:?}", user_id, room_id);
debug_info!("typing stopped {user_id:?} in {room_id:?}");
// update clients
self.typing
.write()
@ -80,31 +82,31 @@ impl Service {
.entry(room_id.to_owned())
.or_default()
.remove(user_id);
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
if self.services.globals.user_is_local(user_id) {
self.federation_send(room_id, user_id, false)?;
self.federation_send(room_id, user_id, false).await?;
}
Ok(())
}
pub async fn wait_for_update(&self, room_id: &RoomId) -> Result<()> {
pub async fn wait_for_update(&self, room_id: &RoomId) {
let mut receiver = self.typing_update_sender.subscribe();
while let Ok(next) = receiver.recv().await {
if next == room_id {
break;
}
}
Ok(())
}
/// Makes sure that typing events with old timestamps get removed.
@ -123,30 +125,30 @@ impl Service {
removable.push(user.clone());
}
}
drop(typing);
};
if !removable.is_empty() {
let typing = &mut self.typing.write().await;
let room = typing.entry(room_id.to_owned()).or_default();
for user in &removable {
debug_info!("typing timeout {:?} in {:?}", &user, room_id);
debug_info!("typing timeout {user:?} in {room_id:?}");
room.remove(user);
}
// update clients
self.last_typing_update
.write()
.await
.insert(room_id.to_owned(), self.services.globals.next_count()?);
if self.typing_update_sender.send(room_id.to_owned()).is_err() {
trace!("receiver found what it was looking for and is no longer interested");
}
// update federation
for user in removable {
if self.services.globals.user_is_local(&user) {
self.federation_send(room_id, &user, false)?;
for user in &removable {
if self.services.globals.user_is_local(user) {
self.federation_send(room_id, user, false).await?;
}
}
}
@ -183,7 +185,7 @@ impl Service {
})
}
fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
async fn federation_send(&self, room_id: &RoomId, user_id: &UserId, typing: bool) -> Result<()> {
debug_assert!(
self.services.globals.user_is_local(user_id),
"tried to broadcast typing status of remote user",
@ -197,7 +199,8 @@ impl Service {
self.services
.sending
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))?;
.send_edu_room(room_id, serde_json::to_vec(&edu).expect("Serialized Edu::Typing"))
.await?;
Ok(())
}

View file

@ -1,8 +1,9 @@
use std::sync::Arc;
use conduit::{utils, Error, Result};
use database::Map;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use conduit::Result;
use database::{Deserialized, Map};
use futures::{Stream, StreamExt};
use ruma::{RoomId, UserId};
use crate::{globals, rooms, Dep};
@ -11,13 +12,13 @@ pub(super) struct Data {
userroomid_highlightcount: Arc<Map>,
roomuserid_lastnotificationread: Arc<Map>,
roomsynctoken_shortstatehash: Arc<Map>,
userroomid_joined: Arc<Map>,
services: Services,
}
struct Services {
globals: Dep<globals::Service>,
short: Dep<rooms::short::Service>,
state_cache: Dep<rooms::state_cache::Service>,
}
impl Data {
@ -28,15 +29,15 @@ impl Data {
userroomid_highlightcount: db["userroomid_highlightcount"].clone(),
roomuserid_lastnotificationread: db["userroomid_highlightcount"].clone(), //< NOTE: known bug from conduit
roomsynctoken_shortstatehash: db["roomsynctoken_shortstatehash"].clone(),
userroomid_joined: db["userroomid_joined"].clone(),
services: Services {
globals: args.depend::<globals::Service>("globals"),
short: args.depend::<rooms::short::Service>("rooms::short"),
state_cache: args.depend::<rooms::state_cache::Service>("rooms::state_cache"),
},
}
}
pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
pub(super) fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
@ -45,128 +46,73 @@ impl Data {
roomuser_id.extend_from_slice(user_id.as_bytes());
self.userroomid_notificationcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
.insert(&userroom_id, &0_u64.to_be_bytes());
self.userroomid_highlightcount
.insert(&userroom_id, &0_u64.to_be_bytes())?;
.insert(&userroom_id, &0_u64.to_be_bytes());
self.roomuserid_lastnotificationread
.insert(&roomuser_id, &self.services.globals.next_count()?.to_be_bytes())?;
Ok(())
.insert(&roomuser_id, &self.services.globals.next_count().unwrap().to_be_bytes());
}
pub(super) fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
pub(super) async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (user_id, room_id);
self.userroomid_notificationcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid notification count in db."))
})
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut userroom_id = user_id.as_bytes().to_vec();
userroom_id.push(0xFF);
userroom_id.extend_from_slice(room_id.as_bytes());
pub(super) async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (user_id, room_id);
self.userroomid_highlightcount
.get(&userroom_id)?
.map_or(Ok(0), |bytes| {
utils::u64_from_bytes(&bytes).map_err(|_| Error::bad_database("Invalid highlight count in db."))
})
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
let mut key = room_id.as_bytes().to_vec();
key.push(0xFF);
key.extend_from_slice(user_id.as_bytes());
Ok(self
.roomuserid_lastnotificationread
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Count in roomuserid_lastprivatereadupdate is invalid."))
})
.transpose()?
.unwrap_or(0))
pub(super) async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
let key = (room_id, user_id);
self.roomuserid_lastnotificationread
.qry(&key)
.await
.deserialized()
.unwrap_or(0)
}
pub(super) fn associate_token_shortstatehash(
&self, room_id: &RoomId, token: u64, shortstatehash: u64,
) -> Result<()> {
pub(super) async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) {
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.get_shortroomid(room_id)
.await
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
self.roomsynctoken_shortstatehash
.insert(&key, &shortstatehash.to_be_bytes())
.insert(&key, &shortstatehash.to_be_bytes());
}
pub(super) fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
let shortroomid = self
.services
.short
.get_shortroomid(room_id)?
.expect("room exists");
let mut key = shortroomid.to_be_bytes().to_vec();
key.extend_from_slice(&token.to_be_bytes());
pub(super) async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<u64> {
let shortroomid = self.services.short.get_shortroomid(room_id).await?;
let key: &[u64] = &[shortroomid, token];
self.roomsynctoken_shortstatehash
.get(&key)?
.map(|bytes| {
utils::u64_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid shortstatehash in roomsynctoken_shortstatehash"))
})
.transpose()
.qry(key)
.await
.deserialized()
}
//TODO: optimize; replace point-queries with dual iteration
pub(super) fn get_shared_rooms<'a>(
&'a self, users: Vec<OwnedUserId>,
) -> Result<Box<dyn Iterator<Item = Result<OwnedRoomId>> + 'a>> {
let iterators = users.into_iter().map(move |user_id| {
let mut prefix = user_id.as_bytes().to_vec();
prefix.push(0xFF);
self.userroomid_joined
.scan_prefix(prefix)
.map(|(key, _)| {
let roomid_index = key
.iter()
.enumerate()
.find(|(_, &b)| b == 0xFF)
.ok_or_else(|| Error::bad_database("Invalid userroomid_joined in db."))?
.0
.saturating_add(1); // +1 because the room id starts AFTER the separator
let room_id = key[roomid_index..].to_vec();
Ok::<_, Error>(room_id)
})
.filter_map(Result::ok)
});
// We use the default compare function because keys are sorted correctly (not
// reversed)
Ok(Box::new(
utils::common_elements(iterators, Ord::cmp)
.expect("users is not empty")
.map(|bytes| {
RoomId::parse(
utils::string_from_bytes(&bytes)
.map_err(|_| Error::bad_database("Invalid RoomId bytes in userroomid_joined"))?,
)
.map_err(|_| Error::bad_database("Invalid RoomId in userroomid_joined."))
}),
))
&'a self, user_a: &'a UserId, user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.services
.state_cache
.rooms_joined(user_a)
.filter(|room_id| self.services.state_cache.is_joined(user_b, room_id))
}
}

View file

@ -3,7 +3,8 @@ mod data;
use std::sync::Arc;
use conduit::Result;
use ruma::{OwnedRoomId, OwnedUserId, RoomId, UserId};
use futures::{pin_mut, Stream, StreamExt};
use ruma::{RoomId, UserId};
use self::data::Data;
@ -22,32 +23,49 @@ impl crate::Service for Service {
}
impl Service {
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) -> Result<()> {
self.db.reset_notification_counts(user_id, room_id)
#[inline]
pub fn reset_notification_counts(&self, user_id: &UserId, room_id: &RoomId) {
self.db.reset_notification_counts(user_id, room_id);
}
pub fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.notification_count(user_id, room_id)
#[inline]
pub async fn notification_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.notification_count(user_id, room_id).await
}
pub fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.highlight_count(user_id, room_id)
#[inline]
pub async fn highlight_count(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.highlight_count(user_id, room_id).await
}
pub fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> Result<u64> {
self.db.last_notification_read(user_id, room_id)
#[inline]
pub async fn last_notification_read(&self, user_id: &UserId, room_id: &RoomId) -> u64 {
self.db.last_notification_read(user_id, room_id).await
}
pub fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) -> Result<()> {
#[inline]
pub async fn associate_token_shortstatehash(&self, room_id: &RoomId, token: u64, shortstatehash: u64) {
self.db
.associate_token_shortstatehash(room_id, token, shortstatehash)
.await;
}
pub fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<Option<u64>> {
self.db.get_token_shortstatehash(room_id, token)
#[inline]
pub async fn get_token_shortstatehash(&self, room_id: &RoomId, token: u64) -> Result<u64> {
self.db.get_token_shortstatehash(room_id, token).await
}
pub fn get_shared_rooms(&self, users: Vec<OwnedUserId>) -> Result<impl Iterator<Item = Result<OwnedRoomId>> + '_> {
self.db.get_shared_rooms(users)
#[inline]
pub fn get_shared_rooms<'a>(
&'a self, user_a: &'a UserId, user_b: &'a UserId,
) -> impl Stream<Item = &RoomId> + Send + 'a {
self.db.get_shared_rooms(user_a, user_b)
}
pub async fn has_shared_rooms<'a>(&'a self, user_a: &'a UserId, user_b: &'a UserId) -> bool {
let get_shared_rooms = self.get_shared_rooms(user_a, user_b);
pin_mut!(get_shared_rooms);
get_shared_rooms.next().await.is_some()
}
}